[
  {
    "path": ".travis.yml",
    "content": "language: python\npython:\n  - \"3.6\"\nenv:\n  matrix:\n    - TF_VERSION=\"1.15\"\ninstall:\n  - pip install -q \"tensorflow==$TF_VERSION\"\n  - pip install -q .[tests]\n  # Make sure we have the latest version of numpy - avoid problems we were\n  # seeing with Python 3\n  - pip install -q -U numpy\nscript:\n  # Check import\n  - python -c \"import kfac; print(kfac.LayerCollection.__name__)\"\n\n  # Run tests\n  - pytest\n\ngit:\n  depth: 3\n"
  },
  {
    "path": "AUTHORS",
    "content": "# This is the official list of TensorFlow authors for copyright purposes.\n# This file is distinct from the CONTRIBUTORS files.\n# See the latter for an explanation.\n\n# Names should be added to this file as:\n# Name or Organization <email address>\n# The email address is not required for organizations.\n\nGoogle Inc.\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright 2019 The TensorFlow Authors.  All rights reserved.\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative 
Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2019, The TensorFlow Authors.\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# K-FAC: Kronecker-Factored Approximate Curvature\n\n[![Travis](https://img.shields.io/travis/tensorflow/kfac.svg)](https://travis-ci.org/tensorflow/kfac)\n\n**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an\napproximate second-order optimization method, in TensorFlow.\n\n[kfac-paper]: https://arxiv.org/abs/1503.05671\n\n## Installation\n\n`kfac` is compatible with Python 2 and 3 and can be installed directly via\n`pip`,\n\n```shell\n# Assumes tensorflow or tensorflow-gpu installed\n$ pip install kfac\n\n# Installs with tensorflow-gpu requirement\n$ pip install 'kfac[tensorflow_gpu]'\n\n# Installs with tensorflow (cpu) requirement\n$ pip install 'kfac[tensorflow]'\n```\n\n## KFAC DOCS\n\nPlease check [KFAC docs][kfac_docs] for a detailed description with examples\nof how to use KFAC. Check the [Keras KFAC docs][keras_docs] for information on\nusing KFAC with Keras.\n\n[kfac_docs]: https://github.com/tensorflow/kfac/tree/master/docs/index.md\n[keras_docs]: https://github.com/tensorflow/kfac/tree/master/kfac/python/keras/README.md\n"
  },
  {
    "path": "docs/applications.md",
    "content": "Coming Soon..\n"
  },
  {
    "path": "docs/contact.md",
    "content": "Topic                | Contact\n-------------------- | ---------------------\n**Questions**        | kfac-users@google.com\n**Development Team** | kfac-dev@google.com\n\nPrimary contacts:\n\n*   James Martens (jamesmartens@google.com)\n\nContributors (past and present):\n\n*   Alok Aggarwal (aloka@google.com)\n*   Daniel Duckworth (duckworthd@google.com)\n*   David Pfau (pfau@google.com)\n*   Dominik Grewe (dominikg@google.com)\n*   Guodong Zhang (gdzhang@google.com)\n*   James Keeling (jtkeeling@google.com)\n*   James Martens (jamesmartens@google.com)\n*   Jimmy Ba (jba@cs.toronto.edu)\n*   Lala Li (lala@google.com)\n*   Matthew Johnson (mattjj@google.com)\n*   Nicholas Vadivelu (nvadivelu@google.com)\n*   Noah Siegel (siegeln@google.com)\n*   Olga Wichrowskaa (olganw@google.com)\n*   Rishabh Kabra (rkabra@google.com)\n*   Roger Grosse (rgrosse@google.com)\n*   Soham De (sohamde@google.com)\n*   Tamas Berghammer (tberghammer@google.com)\n*   Vikram Tankasali (tvikram@google.com)\n*   Zachary Nado (znado@google.com)\n"
  },
  {
    "path": "docs/examples/auto_damp.md",
    "content": "# Automatic tuning of damping parameter.\n\n## Table of Contents\n\n*   [1. Cached Reader](#1-cached-reader)\n*   [2. Build optimizer and set damping parameters](#2-build-optimizer-and-set-damping-parameters)\n*   [TIPS:](#tips)\n    <br>\n\nThe [KFAC damping parameter][kfac_damp] can be auto tuned using\nLevenberg-Marquardt (LM) algorithm. For a detailed description of the algorithm\nrefer to `Section 6` of the [KFAC Paper][kfac_paper]. Note this is still a\nheuristic and may not always produce optimal results. It can be better or worse\nthan a carefully tuned fixed value, depending on the problem.\n\n[kfac_paper]: https://arxiv.org/pdf/1503.05671.pdf\n[kfac_damp]: https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md\n\n**Example code**:\nhttps://github.com/tensorflow/kfac/tree/master/kfac/examples/autoencoder_mnist.py\n\nUsing this method to auto tune damping requires changes to the basic KFAC\ntraining script, which are described below. We only highlight additional steps\nrequired vs training with a fixed damping value (as in the [Convnet\nexample][convexamplesec])\n\n[convexamplesec]: https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md\n\n## 1. Cached Reader\n\nWrap the dataset into `CachedReader`. This allows us to access previous batch of\ndata.\n\n```python\n    cached_reader = data_reader.CachedDataReader(dataset, max_batch_size)\n    minibatch = cached_reader(batch_size)\n```\n\n## 2. 
Build optimizer and set damping parameters\n\n```python\n  optimizer = kfac.PeriodicInvCovUpdateKfacOpt(\n      learning_rate=1.0,\n      damping=150.,\n      momentum=0.95,\n      layer_collection=layer_collection,\n      batch_size=batch_size,\n      adapt_damping=True,\n      prev_train_batch=cached_reader.cached_batch,\n      is_chief=True,\n      loss_fn=loss_fn,\n      damping_adaptation_decay=0.95,\n      damping_adaptation_interval=FLAGS.damping_adaptation_interval,\n  )\n  train_op = optimizer.minimize(loss, global_step=global_step)\n```\n\n## TIPS:\n\n1.  Damping can also be tuned using Population based training ([PBT][PBT_link]).\n    In our observations PBT works on par with auto tuning using LM algorithm,\n    although is obviously more computationally expensive. However if you are\n    already doing PBT for other hyperparams then consider tuning damping using\n    PBT as well.\n\n[PBT_link]: https://arxiv.org/abs/1711.09846\n"
  },
  {
    "path": "docs/examples/convolutional.md",
    "content": "# Convolutional\n\n## Table of Contents\n\n*   [Build the Model](#build-the-model)\n*   [Register the layers and loss](#register-the-layers-and-loss)\n*   [Build the optimizer](#build-the-optimizer)\n*   [Fit the model](#fit-the-model)\n*   [TIPS](#tips)\n    <br>\n\nK-FAC needs to know about the structure of your model in order to effectively\noptimize it. In particular, it needs to know about:\n\n1.  Each convolutional and feed forward layer's inputs and outputs.\n1.  All of the model parameters.\n1.  The type of the loss function and its inputs.\n\nLet's explore how we can use K-FAC to solve digit classification with MNIST\nusing a simple convolutional model. In the following example we will illustrate\nhow to use `PeriodicInvCovUpdateOpt` which is a subclass of `KfacOptimizer`.\n`PeriodicInvCovUpdateOpt` handles placement and execution of covariance and\ninverse ops. We will also illustrate how to register the layers both manually\nand automatically using the graph scanner.\n\n**Code**:\nhttps://github.com/tensorflow/kfac/tree/master/kfac/examples/convnet_mnist_single_main.py\n\n## Build the Model\n\nFirst, we begin by defining a model. In this case, we'll load MNIST and\nconstruct a 5-layer ConvNet. The model has 2 Conv/MaxPool pairs and a final\nlinear layer. 
If we are registering the layers manually we need to keep the\ninputs and outputs and parameters (weights & bias) around, which is illustrated\nhere.\n\n```python\n  # Load a dataset.\n  examples, labels = mnist.load_mnist(\n      data_dir,\n      num_epochs=num_epochs,\n      batch_size=128,\n      use_fake_data=use_fake_data,\n      flatten_images=False)\n\n  # Build a ConvNet.\n  pre0, act0, params0 = conv_layer(\n      layer_id=0, inputs=examples, kernel_size=5, out_channels=16)\n  act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)\n  pre2, act2, params2 = conv_layer(\n      layer_id=2, inputs=act1, kernel_size=5, out_channels=16)\n  act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)\n  flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])\n  logits, params4 = linear_layer(\n      layer_id=4, inputs=flat_act3, output_size=num_labels)\n  loss = tf.reduce_mean(\n      tf.nn.sparse_softmax_cross_entropy_with_logits(\n          labels=labels, logits=logits))\n  accuracy = tf.reduce_mean(\n      tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))\n```\n\n## Register the layers and loss\n\n`layer_collection.auto_register_layers` automatically registers all the layers\nfor typical/standard models. However one must still manually register the loss\nfunction. In the case of cross-entropy loss functions on softmaxes this amounts\nto calling `layer_collection.register_softmax_cross_entropy_loss` with the\nlogits as an argument. 
Note that the inputs/outputs of non-parameterized layers\nsuch as max pooling and reshaping _do not_ need to be registered.\n\n```python\n  # Register parameters with graph_search.\n  tf.logging.info(\"Building KFAC Optimizer.\")\n  layer_collection = kfac.LayerCollection()\n  layer_collection.register_softmax_cross_entropy_loss(logits)\n  # Set the layer at params0 to use a diagonal approximation\n  # instead of default Kronecker factor based approximation.\n  layer_collection.define_linked_parameters(\n        params0, approximation=layer_collection.APPROX_DIAGONAL_NAME)\n  layer_collection.auto_register_layers()\n```\n\nIn the example above we demonstrate how to use a non-default Fisher\napproximation (diagonal) for one of the conv layers. (The default is usually\nKronecker-factored.) This is done by calling\n`layer_collection.define_linked_parameters`, which identifies the given\nvariables as being part of a particular layer, and sets the approximation that\nis to be used for that layer. Any registrations performed later, whether done by\nthe graph scanner or performed manually by the user, will use this approximation\n(unless overridden by the `approx` argument to the registration function).\n\nLayers can also be registered manually. This is required for types of layers\nthat the automatic graph scanner doesn't recognize.\n\nNote that One can also use a combination of manual and automatic registration by\ncalling `auto_register_layers()` after performing some manual registration. Any\nlayers registered manually before will be ignored by the scanner. We register\neach layer's inputs, outputs, and parameters with an instance of\n`LayerCollection`. For convolution layers, we use `register_conv2d`. 
For fully\nconnected (or linear) layers, `register_fully_connected`.\n\n```python\n  # Register parameters manually.\n  tf.logging.info(\"Building KFAC Optimizer.\")\n  layer_collection = kfac.LayerCollection()\n  layer_collection.register_softmax_cross_entropy_loss(logits)\n\n  layer_collection.register_conv2d(params0, (1, 1, 1, 1), \"SAME\", examples,\n                                   pre0,\n                                   approx=kfac_ff.APPROX_DIAGONAL_NAME)\n  layer_collection.register_conv2d(params2, (1, 1, 1, 1), \"SAME\", act1, pre2)\n  layer_collection.register_fully_connected(params4, flat_act3, logits)\n```\n\nIn this example we demonstrate how to use a non-default Fisher approximation\n(diagonal) for one of the layers. (The default is usually Kronecker-factored.)\nThis is done by passing `approx=kfac_ff.APPROX_DIAGONAL_NAME` to the\nregistration function `layer_collection.register_conv2d`. Note that if One has\nalready used `define_linked_parameters` to set the approximation then it is not\nrequired to specify it again via the `approx` argument.\n\n## Build the optimizer\n\nFinally, we instantiate the optimizer. In addition to the `learning_rate` and\n`momentum`, the optimizer has 2 additional hyperparameters,\n\n1.  `cov_ema_decay`: Check [hyper parameters][hyper_params] section for more\n    details.\n1.  `damping`: This is a critical parameter and needs to be tuned for good\n    performance. 
Check [hyper parameters][hyper_params] section for more\n    details.\n\n[hyper_params]:\nhttps://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md\n\n```python\n  # Train with K-FAC.\n  global_step = tf.train.get_or_create_global_step()\n  optimizer = kfac.PeriodicInvCovUpdateKfacOpt(\n      learning_rate=0.0001,\n      damping=0.001,\n      momentum=0.9,\n      cov_ema_decay=0.95,\n      invert_every=10,\n      cov_update_every=1,\n      layer_collection=layer_collection)\n  train_op = optimizer.minimize(loss, global_step=global_step)\n```\n\n## Fit the model\n\nOptimizing with KFAC is similar to using a standard optimizer, where there is an\n\"update op\" that computes and applies the update to the model's parameters.\nHowever, KFAC introduces two additional sets of ops that must also be executed\nas part of the algorithm (although not necessarily at every iteration). These\nare called the \"covariance update ops\" and \"inverse update ops\", respectively.\nThe covariance update ops update the various \"covariance\" matrices used to\ncompute the Fisher block approximations for the layers. The inverse update ops\nmeanwhile are responsible for computing inverses of the approximate Fisher\nblocks (using algorithms that exploit their special structure).\n\n`PeriodicInvCovUpdateKfacOpt`, which is a subclass of `KfacOptimizer` class,\nfolds these extra ops into the standard update op, so that they execute\nperiodically on certain iterations, according to the `cov_update_every` and\n`invert_every` arguments. Users seeking more fine-grained control of the timing\nand placement of the ops can use the base `KfacOptimizer` class.\n\n```python\n  with tf.train.MonitoredTrainingSession() as sess:\n    while not sess.should_stop():\n      global_step_, loss_, accuracy_, _ = sess.run(\n          [global_step, loss, accuracy, train_op])\n```\n\n## TIPS\n\n1.  
Check the [hyper params tuning][hp_tune] section for more details on tuning\n    various KFAC parameters.\n\n[hp_tune]: https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md\n[mlp]: https://en.wikipedia.org/wiki/Multilayer_perceptron\n[preconditioner]: https://en.wikipedia.org/wiki/Preconditioner#Preconditioning_in_optimization\n"
  },
  {
    "path": "docs/examples/distributed_training.md",
    "content": "# Distributed Training\n\n## Table of Contents\n\n*   [Register the layers](#register-the-layers)\n*   [Build the optimizer](#build-the-optimizer)\n*   [Fit the model](#fit-the-model)\n*   [TIPS](#tips)\n    <br>\n\nThis example showcases how to use K-FAC in a distributed setting using\n`SyncReplicas` optimizer. If you are interested in using\n`tf.distribute.Strategy`, we support `MirroredStrategy` and `TPUStrategy`, with\nan example for `TPUStrategy` [here][tpu_strategy_example]. While most methods\nbenefit from increased compute, K-FAC particularly shines as the number of\nworkers (and, in turn, batch size) increases.\n\n[here][https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb]\n\n**Note:** This tutorial extends the single-machine\n[Convolutional example][conv_ex] to distributed training. It is highly\nrecommended you read that first, as shared bits are omitted below!\n\n[conv_ex]:\nhttps://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md\n\n**Note:** This tutorial expects you to be familiar with distributed training.\nCheck out https://www.tensorflow.org/deploy/distributed if this is new to you.\n\n**Example code**:\nhttps://github.com/tensorflow/kfac/tree/master/kfac/examples/convnet_mnist_distributed_main.py\n\n## Build the Model\n\nWhen training on a single machine, one doesn't need to think about which\n\"device\" a variable is placed on (there's only 1 to choose from!). 
In a\ndistributed setting, variables live on [\"Parameter Servers\"][parameter-servers].\nPlacing a variable on a parameter server is as simple as using\n`tf.train.replica_device_setter()`, which is illustrated in the below code.\n\n```python\n  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):\n    pre0, act0, w0, b0 = conv_layer(\n        layer_id=0, inputs=examples, kernel_size=5, out_channels=16)\n    act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)\n    ...\n```\n\n[parameter-servers]: https://www.tensorflow.org/deploy/distributed\n\n## Register the layers\n\nLayer registration is identical to the single-machine case. See [\"Register the\nlayers\"][register-layers-conv] in the Convolutional example for details.\n\n[register-layers-conv]: https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md?#register-the-layers-and-loss\n\n## Build the optimizer\n\nLike the model itself, the K-FAC optimizer also creates variables. Don't forget\nto wrap it in a similar `replica_device_setter()` too!\n\n```python\n  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):\n    ...\n    optimizer = opt.KfacOptimizer(\n        learning_rate=0.0001,\n        cov_ema_decay=0.95,\n        damping=0.001,\n        layer_collection=layer_collection,\n        momentum=0.9)\n    ...\n```\n\n## Fit the model\n\nWhen training on a single-machine, a single training loop is responsible for\nexecuting all of K-FAC's training operations: updating weights, updating\nstatistics, and inverting the preconditioner matrix. As all of the work happens\non a single machine, one stands little to gain by parallelization.\n\nThere are different strategies of parallelizing the gradient, covariance and\ninverse computation across workers in a distributed setting. 
We will illustrate\nhere two such strategies that work specifically with `SyncReplicas` optimizer\nfor distributed training.\n\nThe first strategy for distributed training is to compute gradient in a\ndistributed fashion across all the workers, but have the inverse and covariance\nops executed only on the chief worker.\n\n**Code**:\nhttps://github.com/tensorflow/kfac/tree/master/kfac/examples/convnet.py\n\n```python\n  optimizer = opt.KfacOptimizer(...)\n  sync_optimizer = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)\n  (cov_update_thunks, inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()\n\n  tf.logging.info(\"Starting training.\")\n  hooks = [sync_optimizer.make_session_run_hook(is_chief)]\n\n  def make_update_op(update_thunks):\n    update_ops = [thunk() for thunk in update_thunks]\n    return tf.group(*update_ops)\n\n  if is_chief:\n    cov_update_op = make_update_op(cov_update_thunks)\n    with tf.control_dependencies([cov_update_op]):\n      inverse_op = tf.cond(\n          tf.equal(tf.mod(global_step, invert_every), 0),\n          lambda: make_update_op(inv_update_thunks),\n          tf.no_op)\n      with tf.control_dependencies([inverse_op]):\n        train_op = sync_optimizer.minimize(loss, global_step=global_step)\n  else:\n    train_op = sync_optimizer.minimize(loss, global_step=global_step)\n```\n\nIn the second strategy, each worker's training loop is responsible for executing\nonly one of K-FAC's three training ops,\n\n1.  Compute gradients.\n1.  Workers updating covariance matrices can asynchronously update the moving\n    average similar to the way asynchronous SGD updates weights.\n1.  Workers inverting the preconditioning matrix can independently and\n    asynchronously invert its blocks, one at a time. 
Blocks are chosen according\n    to a randomly shuffled queue.\n\n```python\n  optimizer = opt.KfacOptimizer(...)\n  inv_update_queue = oq.OpQueue(optimizer.inv_updates_dict.values())\n  sync_optimizer = tf.train.SyncReplicasOptimizer(\n      opt=optimizer,\n      replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks))\n  train_op = sync_optimizer.minimize(loss, global_step=global_step)\n\n  with tf.train.MonitoredTrainingSession(...) as sess:\n    while not sess.should_stop():\n      if _is_gradient_task(task_id, num_worker_tasks):\n        learning_op = train_op\n      elif _is_cov_update_task(task_id, num_worker_tasks):\n        learning_op = optimizer.cov_update_op\n      elif _is_inv_update_task(task_id, num_worker_tasks):\n        learning_op = inv_update_queue.next_op(sess)\n\n      global_step_, loss_, statistics_, _ = sess.run(\n          [global_step, loss, statistics, learning_op])\n```\n\n## TIPS\n\n1.  Check the [hyper params tuning][hp_tune] section for more details on tuning\n    various KFAC parameters.\n\n[hp_tune]: https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md\n"
  },
  {
    "path": "docs/examples/parameters.md",
    "content": "# K-FAC Parameters.\n\n## Table of Contents\n\n*   [Damping](#damping)\n*   [Learning Rate](#learning-rate)\n*   [Subsample covariance computation](#subsample-covariance-computation)\n*   [KFAC norm constraint](#kfac-norm-constraint)\n*   [Covariance decay](#covariance-decay)\n*   [Train batch size](#train-batch-size)\n    <br>\n\nWe list below various parameters which can be tuned to improve training and run\ntime performance of K-FAC.\n\n## Damping\n\nDamping is a crucial aspect of K-FAC, as it is for any second order\noptimization/natural gradient method. Broadly speaking, it refers to the\npractice of penalizing or constraining the size of the update in various ways so\nthat it doesn't leave the local region where the quadratic approximation to the\nobjective (which is used to compute the update) remains accurate. This region\ncommonly referred to as the \"trust region\". In some literature damping is called\n\"regularization\" although we will avoid that term due to its related but\ndistinct meaning as a method to combat overfitting.\n\nThe damping strategy used in KFAC is to (approximately) add a multiple of the\nidentity to the Fisher before inverting it. This is essentially equivalent to\nenforcing that the update lie in a spherical trust region centered at the\ncurrent location in parameter space.\n\nThe `damping` parameter represents the multiple of identity which is used.\nHigher values correspond to smaller trust regions, although the precise\nrelationship between `damping` and the size of the trust region depends on the\nscale of the objective, and will vary from iteration to iteration. (If the loss\nfunction is multiplied by scalar 'alpha' then damping should be multiplied by\n'alpha' as well.) 
Higher values of `damping` can allow higher learning rates,\nbut as damping tends to infinity the KFAC updates will start to resemble regular\ngradient descent updates (scaled by `1/damping`).\n\nThe `damping` parameter depends on the scale of the loss function. `damping` is\na critical parameter that needs to be tuned. Options for tuning include a grid\nsweep (must be simultaneous with learning rate optimization - NOT independent)\nor auto-tuned using the Levenberg-Marquardt (LM) algorithm (see the [`Auto\nDamping`][auto_damping] section for further details). For grid sweeps a typical\nrange to consider would be logarithmically spaced values between `1e-5` to\n`100`, although the optimal value could be any non-negative real number in\nprinciple (because the scale of the loss is arbitrary). Another option for\ntuning `damping` is [`Population based training`][PBT] (PBT).\n\nRefer to section `6` of the [KFAC paper][kfac_paper] for a more detailed\ndiscussion of damping and how it can be used/tuned in KFAC\n\n[auto_damping]:\nhttps://github.com/tensorflow/kfac/tree/master/docs/examples/auto_damp.md\n[PBT]:\nhttps://arxiv.org/abs/1711.09846\n[kfac_paper]:\nhttps://arxiv.org/pdf/1503.05671.pdf\n\n## Learning Rate\n\nTypically sweep over values in the range 1e-5 to 100. It is important to tune\nthe learning in conjunction with damping, since the two are closely coupled\n(higher damping allows higher learning rates). The learning rate can also be\ntuned using PBT. Note that the optimal learning rate will be generally different\nfrom the learning rate used for SGD/RMSProp/Adam optimizer.\n\n## Subsample covariance computation\n\nIf you are using Conv layers and observe that the KFAC iterations is\nsignificantly slower than Adam or if you run out of memory then a possible\nremedy is to use subsampling in the covariance computation. To turn on\nsubsampling set `kfac_ff.sub_sample_inputs` to `True` and\n`kfac_ff.sub_sample_outer_products` to `True`. 
The former flag subsamples the\nbatch of inputs used for covariance computation and the latter flag subsamples\nextracted patches based on the size of the covariance matrix. Check the\ndocumentation of `tensorflow_kfac.fisher_factors` for a detailed explanation of\nvarious subsampling parameters. Also check the [`Distributed training`][dist_train]\nsection for how to distribute the computation of these ops over multiple\ndevices.\n\n[dist_train]:\nhttps://github.com/tensorflow/kfac/tree/master/docs/examples/distributed_training.md\n\n## KFAC norm constraint\n\nScales the K-FAC update so that its approximate Fisher norm is bounded.\nTypically use an initial value of 1.0 and tune it using PBT or perform grid\nsearch. Norm constraint can be used as an alternative to learning rate schedules.\nSee Section 5 of the [Distributed Second-Order Optimization using\nKronecker-Factored Approximations][ba_paper] paper for further details.\n\n[ba_paper]:\nhttps://jimmylba.github.io/papers/nsync.pdf\n\n## Covariance decay\n\nDuring the course of the algorithm, an exponential moving average tracks\nstatistics for each layer. Slower decays mean that the statistics are based on\nmore data, but will suffer more from the issue of staleness (because of the\nchanging model parameters). This parameter can usually be left at its default\nvalue but may occasionally matter for some problems. In such cases some\nreasonable values to sweep over are `[0.9, 0.95, 0.99, 0.999]`.\n\n## Train batch size\n\nTypically try using a larger batch size compared to training with\nSGD/RMSprop/Adam.\n"
  },
  {
    "path": "docs/index.md",
    "content": "# Home\n\nKronecker factored approximate curvature\n\n**K-FAC in TensorFlow** is an implementation of K-FAC, an approximate\nsecond-order optimization method, in TensorFlow. K-FAC can converge much\nfaster than SGD or Adam on certain neural network architectures (especially when\nusing larger batch sizes), but may be closer in performance on other\narchitectures (such as ResNets).\n\n## Table of Contents\n\n*   [What is K-FAC?](#what-is-k-fac)\n*   [Why should I use K-FAC?](#why-should-i-use-k-fac)\n*   [How do I use K-FAC?](#how-do-i-use-k-fac)\n\n## What is K-FAC?\n\nK-FAC, short for \"Kronecker-factored Approximate Curvature\", is an approximation\nto the [Natural Gradient][natural_gradient] algorithm designed specifically for\nneural networks. It maintains an approximation to the [Fisher Information\nmatrix][fisher_information], whose inverse is used as a preconditioner for\n(stochastic) gradient descent.\n\nK-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations.\nHowever it is slightly more restrictive compared to SGD, Adam as it makes some\nassumptions on the structure of the model and the loss function.\n\nUnlike most optimizers, K-FAC exploits structure in the model itself (e.g. \"What\nare the weights for layer i?\"). As such, you must add some additional code while\nconstructing your model to use K-FAC.\n\n[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746\n[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form\n\n## Why should I use K-FAC?\n\nK-FAC can take advantage of the curvature of the optimization problem, resulting\nin **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same\nloss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. 
See\nreference code [here][autoencoder-code] and plots comparing KFAC with SGD below.\n\n![](https://github.com/tensorflow/kfac/tree/master/kfac/g3doc/sgd_comparison.png?raw=true)\n\n[autoencoder-code]: https://github.com/tensorflow/kfac/tree/master/kfac/examples/autoencoder_mnist.py\n\n## How do I use K-FAC?\n\nUsing K-FAC requires three steps,\n\n1.  Registering layer inputs, weights, and pre-activations with a\n    `kfac.LayerCollection`.\n2.  Register loss functions.\n3.  Minimizing the loss with a `kfac.PeriodicInvCovUpdateKfacOpt`.\n\n```python\nimport kfac\n# Build model.\nw = tf.get_variable(\"w\", ...)\nb = tf.get_variable(\"b\", ...)\nlogits = tf.matmul(x, w) + b\nloss = tf.reduce_mean(\n  tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))\n\n# Register loss.\nlayer_collection = kfac.LayerCollection()\nlayer_collection.register_softmax_cross_entropy_loss(logits)\n\n# Register layers.\nlayer_collection.auto_register_layers()\n\n# Construct training ops.\noptimizer = kfac.PeriodicInvCovUpdateKfacOpt(..., layer_collection=layer_collection)\ntrain_op = optimizer.minimize(loss)\n\n# Minimize loss.\nwith tf.Session() as sess:\n  ...\n  sess.run([train_op])\n```\n\nCheck out the Convnet training [example][convexamplesec] for more details. 
Also\ncheck the [`PeriodicInvCovUpdate`][periodicincovupdate] optimizer to see how the\ncovariance and inverse ops placement and execution can be handled\nautomatically.\n\n[convexamplesec]: https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md\n[periodicincovupdate]: https://github.com/tensorflow/kfac/tree/master/kfac/python/ops/kfac_utils/periodic_inv_cov_update_kfac_opt.py\n\n## Table of contents\n\n*   [Home](https://github.com/tensorflow/kfac/tree/master/docs/index.md)\n*   User Guide\n    *   [Keras](https://github.com/tensorflow/kfac/tree/master/kfac/python/keras/README.md)\n    *   [Convolutional](https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md)\n    *   [Auto damping](https://github.com/tensorflow/kfac/tree/master/docs/examples/auto_damp.md)\n    *   [Distributed Training](https://github.com/tensorflow/kfac/tree/master/docs/examples/distributed_training.md)\n    *   [Parameters](https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md)\n*   [Applications](https://github.com/tensorflow/kfac/tree/master/docs/applications.md)\n*   [Some KFAC-Papers](https://github.com/tensorflow/kfac/tree/master/docs/papers.md)\n*   [Contact](https://github.com/tensorflow/kfac/tree/master/docs/contact.md)\n"
  },
  {
    "path": "docs/papers.md",
    "content": "*   Martens, James, and Roger Grosse.\n    [\"Optimizing neural networks with Kronecker-factored approximate curvature.\"][kfac_paper]\n    International Conference on Machine Learning. 2015.\n*   Grosse, Roger, and James Martens.\n    [\"A Kronecker-factored approximate Fisher matrix for convolution layers.\"][kfac_conv_paper]\n    International Conference on Machine Learning. 2016.\n*   Ba, Jimmy, Roger Grosse, and James Martens. [\"Distributed Second-Order\n    Optimization using Kronecker-Factored Approximations.\"][distributed_kfac]\n    (2016).\n*   James Martens, Jimmy Ba, Matt Johnson. [\"Kronecker-factored Curvature\n    Approximations for Recurrent Neural Networks.\"][kfac_rnn_paper] ICLR. 2018.\n\n[kfac_paper]: https://arxiv.org/abs/1503.05671\n[kfac_conv_paper]: https://arxiv.org/abs/1602.01407\n[kfac_rnn_paper]: https://openreview.net/forum?id=HyMTkQZAb\n[distributed_kfac]: https://openreview.net/forum?id=SkkTMpjex\n"
  },
  {
    "path": "docs/sitemap.md",
    "content": "*   [Home](https://github.com/tensorflow/kfac/tree/master/docs/index.md)\n*   User Guide\n    *   [Convolutional](https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md)\n    *   [Auto damping](https://github.com/tensorflow/kfac/tree/master/docs/examples/auto_damp.md)\n    *   [Distributed Training](https://github.com/tensorflow/kfac/tree/master/docs/examples/distributed_training.md)\n    *   [Parameters](https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md)\n*   [Applications](https://github.com/tensorflow/kfac/tree/master/docs/applications.md)\n*   [Some KFAC-Papers](https://github.com/tensorflow/kfac/tree/master/docs/papers.md)\n*   [Contact](https://github.com/tensorflow/kfac/tree/master/docs/contact.md)\n"
  },
  {
    "path": "kfac/__init__.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Kronecker-factored Approximate Curvature Optimizer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# pylint: disable=unused-import,line-too-long\nfrom kfac.python import keras\n\nfrom kfac.python.ops import curvature_matrix_vector_products\nfrom kfac.python.ops import estimator\nfrom kfac.python.ops import fisher_blocks\nfrom kfac.python.ops import fisher_factors\nfrom kfac.python.ops import layer_collection\nfrom kfac.python.ops import linear_operator\nfrom kfac.python.ops import loss_functions\nfrom kfac.python.ops import op_queue\nfrom kfac.python.ops import optimizer\nfrom kfac.python.ops import placement\nfrom kfac.python.ops import utils\nfrom kfac.python.ops.kfac_utils import async_inv_cov_update_kfac_opt\nfrom kfac.python.ops.kfac_utils import data_reader\nfrom kfac.python.ops.kfac_utils import data_reader_alt\nfrom kfac.python.ops.kfac_utils import periodic_inv_cov_update_kfac_opt\n\nfrom kfac.python.ops.tensormatch import graph_matcher\nfrom kfac.python.ops.tensormatch import graph_search\n\n# pylint: enable=unused-import\n\n# pylint: disable=invalid-name\nLayerCollection = layer_collection.LayerCollection\nKfacOptimizer = 
optimizer.KfacOptimizer\nPeriodicInvCovUpdateKfacOpt = periodic_inv_cov_update_kfac_opt.PeriodicInvCovUpdateKfacOpt\nAsyncInvCovUpdateKfacOpt = async_inv_cov_update_kfac_opt.AsyncInvCovUpdateKfacOpt\n\nCurvatureMatrixVectorProductComputer = curvature_matrix_vector_products.CurvatureMatrixVectorProductComputer\n\n# pylint: enable=invalid-name, line-too-long\n"
  },
  {
    "path": "kfac/examples/__init__.py",
    "content": "\n"
  },
  {
    "path": "kfac/examples/autoencoder_mnist.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Full implementation of deep autoencoder experiment from original K-FAC paper.\n\nThis script demonstrates training using KFAC optimizer, updating the damping\nparameter according to the Levenberg-Marquardt rule, and using the quadratic\nmodel method for adapting the learning rate and momentum parameters.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\n# Dependency imports\nfrom absl import flags\nimport kfac\nimport sonnet as snt\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.examples import mnist\nfrom kfac.python.ops.kfac_utils import data_reader\nfrom kfac.python.ops.kfac_utils import data_reader_alt\n\n\n# Model parameters\n_ENCODER_SIZES = [1000, 500, 250, 30]\n_DECODER_SIZES = [250, 500, 1000]\n_NONLINEARITY = tf.tanh  # Note: sigmoid cannot be used with the default init.\n_WEIGHTS_INITIALIZER = None  # Default init\n\n\nflags.DEFINE_integer('train_steps', 10000, 'Number of training steps.')\n\nflags.DEFINE_integer('inverse_update_period', 5,\n                     '# of steps between computing inverse of Fisher factor '\n                     'matrices.')\n\nflags.DEFINE_integer('cov_update_period', 1,\n                     '# of steps between computing 
covaraiance matrices.')\n\nflags.DEFINE_integer('damping_adaptation_interval', 5,\n                     '# of steps between updating the damping parameter.')\n\nflags.DEFINE_integer('num_burnin_steps', 5, 'Number of steps at the '\n                     'start of training where the optimizer will only perform '\n                     'cov updates.')\n\nflags.DEFINE_integer('seed', 12345, 'Random seed')\n\nflags.DEFINE_float('learning_rate', 3e-3,\n                   'Learning rate to use when lrmu_adaptation=\"off\".')\n\nflags.DEFINE_float('momentum', 0.9,\n                   'Momentum decay value to use when '\n                   'lrmu_adaptation=\"off\" or \"only_lr\".')\n\nflags.DEFINE_float('damping', 1e-2, 'The fixed damping value to use. This is '\n                   'ignored if adapt_damping is True.')\n\nflags.DEFINE_float('l2_reg', 1e-5,\n                   'L2 regularization applied to weight matrices.')\n\nflags.DEFINE_boolean('update_damping_immediately', True, 'Adapt the damping '\n                     'immediately after the parameter update (i.e. in the same '\n                     'sess.run() call).  Only safe if everything is a resource '\n                     'variable.')\n\nflags.DEFINE_boolean('use_batch_size_schedule', True,\n                     'If True then we use the growing mini-batch schedule from '\n                     'the original K-FAC paper.')\n\nflags.DEFINE_integer('batch_size', 1024,\n                     'The size of the mini-batches to use if not using the '\n                     'schedule.')\n\nflags.DEFINE_string('lrmu_adaptation', 'on',\n                    'If set to \"on\" then we use the quadratic model '\n                    'based learning-rate and momentum adaptation method from '\n                    'the original paper. Note that this only works well in '\n                    'practice when use_batch_size_schedule=True. 
Can also '\n                    'be set to \"off\" and \"only_lr\", which turns '\n                    'it off, or uses a version where the momentum parameter '\n                    'is fixed (resp.).')\n\nflags.DEFINE_boolean('use_alt_data_reader', True,\n                     'If True we use the alternative data reader for MNIST '\n                     'that is faster for small datasets.')\n\nflags.DEFINE_string('device', '/gpu:0',\n                    'The device to run the major ops on.')\n\nflags.DEFINE_boolean('adapt_damping', True,\n                     'If True we use the LM rule for damping adaptation as '\n                     'described in the original K-FAC paper.')\n\n# When using damping adaptation it is advisable to start with a high\n# value. This value is probably far too high to use for most neural nets\n# if you aren't using damping adaptation. (Although it always depends on\n# the scale of the loss.)\nflags.DEFINE_float('initial_damping', 150.0,\n                   'The initial damping value to use when adapt_damping is '\n                   'True.')\n\nflags.DEFINE_string('optimizer', 'kfac',\n                    'The optimizer to use. Can be kfac or adam. If adam is '\n                    'used the various K-FAC hyperparameter map roughly on to '\n                    'their Adam equivalents.')\n\nflags.DEFINE_boolean('auto_register_layers', True,\n                     'If True we use the automatic registration feature '\n                     'which relies on scanning the TF graph. Otherwise '\n                     'registration is done manually by this script during '\n                     'the construction of the model.')\n\nflags.DEFINE_boolean('use_keras_model', False,\n                     'If True, we use a Keras version of the autoencoder '\n                     'model. 
Only works when auto_register_layers=True.')\n\nflags.DEFINE_boolean('use_sequential_for_keras', True,\n                     'If True, we construct the Keras model using the '\n                     'Sequential class.')\n\nflags.DEFINE_boolean('use_control_flow_v2', False, 'If True, we use Control '\n                     'Flow V2. Defaults to False.')\n\n\nFLAGS = flags.FLAGS\n\n\ndef make_train_op(minibatch,\n                  batch_size,\n                  batch_loss,\n                  layer_collection,\n                  loss_fn,\n                  prev_train_batch=None,\n                  placement_strategy=None,\n                  print_logs=False,\n                  tf_replicator=None):\n  \"\"\"Constructs optimizer and train op.\n\n  Args:\n    minibatch: A list/tuple of Tensors (typically representing the current\n      mini-batch of input images and labels).\n    batch_size: Tensor of shape (). Size of the training mini-batch.\n    batch_loss: Tensor of shape (). Mini-batch loss tensor.\n    layer_collection: LayerCollection object. Registry for model parameters.\n      Required when using a K-FAC optimizer.\n    loss_fn: Function which takes as input a mini-batch and returns the loss.\n    prev_train_batch: `Tensor` of the previous training batch, can be accessed\n      from the data_reader.CachedReader cached_batch property. (Default: None)\n    placement_strategy: `str`, the placement_strategy argument for\n      `KfacOptimizer`. (Default: None)\n    print_logs: `Bool`. If True we print logs using K-FAC's built-in\n      tf.print-based logs printer. (Default: False)\n    tf_replicator: A Replicator object or None. 
If not None, K-FAC will set\n        itself up to work inside of the provided TF-Replicator object.\n        (Default: None)\n\n  Returns:\n    train_op: Op that can be used to update model parameters.\n    optimizer: Optimizer used to produce train_op.\n\n  Raises:\n    ValueError: If layer_collection is None when K-FAC is selected as an\n      optimization method.\n  \"\"\"\n  global_step = tf.train.get_or_create_global_step()\n\n  if FLAGS.optimizer == 'kfac':\n    if FLAGS.lrmu_adaptation == 'on':\n      learning_rate = None\n      momentum = None\n      momentum_type = 'qmodel'\n    elif FLAGS.lrmu_adaptation == 'only_lr':\n      learning_rate = None\n      momentum = FLAGS.momentum\n      momentum_type = 'qmodel_fixedmu'\n    elif FLAGS.lrmu_adaptation == 'off':\n      learning_rate = FLAGS.learning_rate\n      momentum = FLAGS.momentum\n      # momentum_type = 'regular'\n      momentum_type = 'adam'\n\n    if FLAGS.adapt_damping:\n      damping = FLAGS.initial_damping\n    else:\n      damping = FLAGS.damping\n\n    optimizer = kfac.PeriodicInvCovUpdateKfacOpt(\n        invert_every=FLAGS.inverse_update_period,\n        cov_update_every=FLAGS.cov_update_period,\n        learning_rate=learning_rate,\n        damping=damping,\n        cov_ema_decay=0.95,\n        momentum=momentum,\n        momentum_type=momentum_type,\n        layer_collection=layer_collection,\n        batch_size=batch_size,\n        num_burnin_steps=FLAGS.num_burnin_steps,\n        adapt_damping=FLAGS.adapt_damping,\n        l2_reg=FLAGS.l2_reg,\n        placement_strategy=placement_strategy,\n        print_logs=print_logs,\n        tf_replicator=tf_replicator,\n        # Note that many of the arguments below don't do anything when\n        # adapt_damping=False.\n        update_damping_immediately=FLAGS.update_damping_immediately,\n        is_chief=True,\n        prev_train_batch=prev_train_batch,  # We don't actually need this unless\n                                            # 
update_damping_immediately is\n                                            # False.\n        loss=batch_loss,\n        loss_fn=loss_fn,\n        damping_adaptation_decay=0.95,\n        damping_adaptation_interval=FLAGS.damping_adaptation_interval,\n        min_damping=1e-6,\n        train_batch=minibatch,\n        )\n\n  elif FLAGS.optimizer == 'adam':\n    optimizer = tf.train.AdamOptimizer(\n        learning_rate=FLAGS.learning_rate,\n        beta1=FLAGS.momentum,\n        epsilon=FLAGS.damping,\n        beta2=0.99)\n\n  return optimizer.minimize(batch_loss, global_step=global_step), optimizer\n\n\nclass AutoEncoder(snt.AbstractModule):\n  \"\"\"Simple autoencoder module.\"\"\"\n\n  def __init__(self,\n               input_size,\n               regularizers=None,\n               initializers=None,\n               custom_getter=None,\n               name='AutoEncoder'):\n    super(AutoEncoder, self).__init__(custom_getter=custom_getter, name=name)\n\n    if initializers is None:\n      initializers = {'w': tf.glorot_uniform_initializer(),\n                      'b': tf.zeros_initializer()}\n    if regularizers is None:\n      regularizers = {'w': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),\n                      'b': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),}\n\n    with self._enter_variable_scope():\n      self._encoder = snt.nets.MLP(\n          output_sizes=_ENCODER_SIZES,\n          regularizers=regularizers,\n          initializers=initializers,\n          custom_getter=custom_getter,\n          activation=_NONLINEARITY,\n          activate_final=False)\n      self._decoder = snt.nets.MLP(\n          output_sizes=_DECODER_SIZES + [input_size],\n          regularizers=regularizers,\n          initializers=initializers,\n          custom_getter=custom_getter,\n          activation=_NONLINEARITY,\n          activate_final=False)\n\n  def _build(self, inputs):\n    code = self._encoder(inputs)\n    output = self._decoder(code)\n\n    return output\n\n\nclass 
MLPManualReg(snt.AbstractModule):\n\n  def __init__(self,\n               output_sizes,\n               regularizers=None,\n               initializers=None,\n               custom_getter=None,\n               activation=_NONLINEARITY,\n               activate_final=False,\n               name='MLP'):\n\n    super(MLPManualReg, self).__init__(custom_getter=custom_getter, name=name)\n\n    self._output_sizes = output_sizes\n    self._activation = activation\n    self._activate_final = activate_final\n\n    with self._enter_variable_scope():\n      self._layers = [snt.Linear(self._output_sizes[i],\n                                 name='linear_{}'.format(i),\n                                 initializers=initializers,\n                                 regularizers=regularizers,\n                                 custom_getter=custom_getter,\n                                 use_bias=True)\n                      for i in range(len(self._output_sizes))]\n\n  def _build(self, inputs, layer_collection=None):\n    net = inputs\n    for i in range(len(self._output_sizes)):\n      layer_inputs = net\n      net = self._layers[i](net)\n      layer_outputs = net\n      params = (self._layers[i].w, self._layers[i].b)\n\n      if layer_collection is not None:\n        layer_collection.register_fully_connected(params,\n                                                  layer_inputs,\n                                                  layer_outputs,\n                                                  reuse=False)\n\n      if i < len(self._output_sizes) - 1 or self._activate_final:\n        net = self._activation(net)\n\n    return net\n\n\nclass AutoEncoderManualReg(snt.AbstractModule):\n  \"\"\"Simple autoencoder module.\"\"\"\n\n  def __init__(self,\n               input_size,\n               regularizers=None,\n               initializers=None,\n               custom_getter=None,\n               name='AutoEncoder'):\n    super(AutoEncoderManualReg, 
self).__init__(custom_getter=custom_getter,\n                                               name=name)\n\n    if initializers is None:\n      initializers = {'w': tf.glorot_uniform_initializer(),\n                      'b': tf.zeros_initializer()}\n    if regularizers is None:\n      regularizers = {'w': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),\n                      'b': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),}\n\n    with self._enter_variable_scope():\n      self._encoder = MLPManualReg(\n          output_sizes=_ENCODER_SIZES,\n          regularizers=regularizers,\n          initializers=initializers,\n          custom_getter=custom_getter,\n          activation=_NONLINEARITY,\n          activate_final=False)\n      self._decoder = MLPManualReg(\n          output_sizes=_DECODER_SIZES + [input_size],\n          regularizers=regularizers,\n          initializers=initializers,\n          custom_getter=custom_getter,\n          activation=_NONLINEARITY,\n          activate_final=False)\n\n  def _build(self, inputs, layer_collection=None):\n    code = self._encoder(inputs, layer_collection=layer_collection)\n    output = self._decoder(code, layer_collection=layer_collection)\n\n    return output\n\n\ndef get_keras_autoencoder(**input_kwargs):\n  \"\"\"Returns autoencoder made with Keras.\n\n  Args:\n    **input_kwargs: Arguments to pass to tf.keras.layers.Input. 
You must include\n      either the 'shape' or 'tensor' kwarg.\n\n  Returns:\n    A tf.keras.Model, the Autoencoder.\n  \"\"\"\n  layers = tf.keras.layers\n  regularizers = tf.keras.regularizers\n\n  dense_kwargs = {\n      'kernel_initializer': tf.glorot_uniform_initializer(),\n      'bias_initializer': tf.zeros_initializer(),\n      'kernel_regularizer': regularizers.l2(l=FLAGS.l2_reg),\n      'bias_regularizer': regularizers.l2(l=FLAGS.l2_reg),\n  }\n\n  if FLAGS.use_sequential_for_keras:\n    model = tf.keras.Sequential()\n    # Create Encoder\n    model.add(layers.Input(**input_kwargs))\n    for size in _ENCODER_SIZES[:-1]:\n      model.add(layers.Dense(\n          size, activation=_NONLINEARITY, **dense_kwargs))\n    model.add(layers.Dense(_ENCODER_SIZES[-1], **dense_kwargs))\n\n    # Create Decoder\n    for size in _DECODER_SIZES:\n      model.add(layers.Dense(size, activation=_NONLINEARITY, **dense_kwargs))\n    model.add(layers.Dense(784, **dense_kwargs))\n\n  else:\n    # Make sure you always wrap the input in keras\n    inputs = layers.Input(**input_kwargs)\n\n    x = inputs\n    # Create Encoder\n    for size in _ENCODER_SIZES[:-1]:\n      x = layers.Dense(size, activation=_NONLINEARITY, **dense_kwargs)(x)\n    x = layers.Dense(_ENCODER_SIZES[-1], **dense_kwargs)(x)\n\n    # Create Decoder\n    for size in _DECODER_SIZES:\n      x = layers.Dense(size, activation=_NONLINEARITY, **dense_kwargs)(x)\n    x = layers.Dense(784, **dense_kwargs)(x)\n\n    model = tf.keras.Model(inputs=inputs, outputs=x)\n\n  return model\n\n\ndef compute_squared_error(logits, targets):\n  \"\"\"Compute mean squared error.\"\"\"\n  return tf.reduce_sum(\n      tf.reduce_mean(tf.square(targets - tf.nn.sigmoid(logits)), axis=0))\n\n\ndef compute_loss(logits=None,\n                 labels=None,\n                 return_error=False,\n                 model=None):\n  \"\"\"Compute loss value.\"\"\"\n  if FLAGS.use_keras_model:\n    total_regularization_loss = tf.add_n(model.losses)\n  
else:\n    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)\n    total_regularization_loss = tf.add_n(graph_regularizers)\n\n  loss_matrix = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,\n                                                        labels=labels)\n  loss = tf.reduce_sum(tf.reduce_mean(loss_matrix, axis=0))\n  regularized_loss = loss + total_regularization_loss\n\n  if return_error:\n    squared_error = compute_squared_error(logits, labels)\n    return regularized_loss, squared_error\n\n  return regularized_loss\n\n\ndef load_mnist():\n  \"\"\"Creates MNIST dataset and wraps it inside cached data reader.\n\n  Returns:\n    cached_reader: `data_reader.CachedReader` instance which wraps MNIST\n      dataset.\n    num_examples: int. The number of training examples.\n  \"\"\"\n  # Wrap the data set into cached_reader which provides variable sized training\n  # and caches the read train batch.\n\n  if not FLAGS.use_alt_data_reader:\n    # Version 1 using data_reader.py (slow!)\n    dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)\n    if FLAGS.use_batch_size_schedule:\n      max_batch_size = num_examples\n    else:\n      max_batch_size = FLAGS.batch_size\n\n    # Shuffle before repeat is correct unless you want repeat cases in the\n    # same batch.\n    dataset = (dataset.shuffle(num_examples).repeat()\n               .batch(max_batch_size).prefetch(5))\n    dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()\n\n    # This version of CachedDataReader requires the dataset to be shuffled\n    return data_reader.CachedDataReader(dataset, max_batch_size), num_examples\n\n  else:\n    # Version 2 using data_reader_alt.py (faster)\n    images, labels, num_examples = mnist.load_mnist_as_tensors(\n        flatten_images=True)\n    dataset = (images, labels)\n\n    # This version of CachedDataReader requires the dataset to NOT be shuffled\n    return 
data_reader_alt.CachedDataReader(dataset, num_examples), num_examples\n\n\ndef _get_batch_size_schedule(minibatch_maxsize):\n  \"\"\"Returns training batch size schedule.\"\"\"\n  minibatch_maxsize_targetiter = 500\n  minibatch_startsize = 1000\n\n  div = (float(minibatch_maxsize_targetiter-1)\n         / math.log(float(minibatch_maxsize)/minibatch_startsize, 2))\n  return [\n      min(int(2.**(float(k)/div) * minibatch_startsize), minibatch_maxsize)\n      for k in range(minibatch_maxsize_targetiter)\n  ]\n\n\ndef construct_train_quants():\n  \"\"\"Returns tensors and optimizer required to run the autoencoder.\"\"\"\n  with tf.device(FLAGS.device):\n    # Load dataset.\n    cached_reader, num_examples = load_mnist()\n    batch_size_schedule = _get_batch_size_schedule(num_examples)\n    batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')\n\n    minibatch = cached_reader(batch_size)\n    features, _ = minibatch\n\n    if FLAGS.auto_register_layers:\n      if FLAGS.use_keras_model:\n        training_model = get_keras_autoencoder(tensor=features)\n      else:\n        training_model = AutoEncoder(784)\n    else:\n      training_model = AutoEncoderManualReg(784)\n\n    layer_collection = kfac.LayerCollection()\n\n    def loss_fn(minibatch, logits=None, return_error=False):\n      features, _ = minibatch\n      if logits is None:\n        logits = training_model(features)\n\n      return compute_loss(\n          logits=logits,\n          labels=features,\n          return_error=return_error,\n          model=training_model)\n\n    if FLAGS.use_keras_model:\n      logits = training_model.output\n    else:\n      if FLAGS.auto_register_layers:\n        logits = training_model(features)\n      else:\n        logits = training_model(features, layer_collection=layer_collection)\n\n    (batch_loss, batch_error) = loss_fn(\n        minibatch, logits=logits, return_error=True)\n\n    # Make sure never to confuse this with 
register_softmax_cross_entropy_loss!\n    layer_collection.register_sigmoid_cross_entropy_loss(logits,\n                                                         seed=FLAGS.seed + 1)\n    if FLAGS.auto_register_layers:\n      layer_collection.auto_register_layers()\n\n    # Make training op\n    train_op, opt = make_train_op(\n        minibatch,\n        batch_size,\n        batch_loss,\n        layer_collection,\n        loss_fn=loss_fn,\n        prev_train_batch=cached_reader.cached_batch)\n\n  return train_op, opt, batch_loss, batch_error, batch_size_schedule, batch_size\n\n\ndef main(_):\n\n  # If using update_damping_immediately resource variables must be enabled.\n  # Would recommend always enabling them anyway.\n  if FLAGS.update_damping_immediately:\n    tf.enable_resource_variables()\n\n  if FLAGS.use_control_flow_v2:\n    tf.enable_control_flow_v2()\n\n  if not FLAGS.auto_register_layers and FLAGS.use_keras_model:\n    raise ValueError('Require auto_register_layers=True when using Keras '\n                     'model.')\n\n  tf.set_random_seed(FLAGS.seed)\n  (train_op, opt, batch_loss, batch_error, batch_size_schedule,\n   batch_size) = construct_train_quants()\n\n  global_step = tf.train.get_or_create_global_step()\n\n  if FLAGS.optimizer == 'kfac':\n    # We need to put the control depenency on train_op here so that we are\n    # guaranteed to get the up-to-date values of these various quantities.\n    # Otherwise there is a race condition and we might get the old values,\n    # nondeterministically. 
Another solution would be to get these values in\n    # a separate sess.run call, but this can sometimes cause problems with\n    # training frameworks that use hooks (see the comments below).\n    with tf.control_dependencies([train_op]):\n      learning_rate = opt.learning_rate\n      momentum = opt.momentum\n      damping = opt.damping\n      rho = opt.rho\n      qmodel_change = opt.qmodel_change\n\n  # Without setting allow_soft_placement=True there will be problems when\n  # the optimizer tries to place certain ops like \"mod\" on the GPU (which isn't\n  # supported).\n  config = tf.ConfigProto(allow_soft_placement=True)\n\n  # It's good practice to put everything into a single sess.run call. The\n  # reason is that certain \"training frameworks\" like to run hooks at each\n  # sess.run call, and there is an implicit expectation there will only\n  # be one sess.run call every \"iteration\" of the \"optimizer\". For example,\n  # a framework might try to print the loss at each sess.run call, causing\n  # the mini-batch to be advanced, thus completely breaking the \"cached\n  # batch\" mechanism that the damping adaptation method may rely on. (Plus\n  # there will also be the extra cost of having to reevaluate the loss\n  # twice.)  
That being said we don't completely do that here because it's\n  # inconvenient.\n\n  # Train model.\n  with tf.train.MonitoredTrainingSession(save_checkpoint_secs=30,\n                                         config=config) as sess:\n    for _ in range(FLAGS.train_steps):\n      i = sess.run(global_step)\n\n      if FLAGS.use_batch_size_schedule:\n        batch_size_ = batch_size_schedule[min(i, len(batch_size_schedule) - 1)]\n      else:\n        batch_size_ = FLAGS.batch_size\n\n      if FLAGS.optimizer == 'kfac':\n        (_, batch_loss_, batch_error_, learning_rate_, momentum_, damping_,\n         rho_, qmodel_change_) = sess.run([train_op, batch_loss, batch_error,\n                                           learning_rate, momentum, damping,\n                                           rho, qmodel_change],\n                                          feed_dict={batch_size: batch_size_})\n      else:\n        _, batch_loss_, batch_error_ = sess.run(\n            [train_op, batch_loss, batch_error],\n            feed_dict={batch_size: batch_size_})\n\n      # Print training stats.\n      tf.logging.info(\n          'iteration: %d', i)\n      tf.logging.info(\n          'mini-batch size: %d | mini-batch loss = %f | mini-batch error = %f ',\n          batch_size_, batch_loss_, batch_error_)\n\n      if FLAGS.optimizer == 'kfac':\n        tf.logging.info(\n            'learning_rate = %f | momentum = %f',\n            learning_rate_, momentum_)\n        tf.logging.info(\n            'damping = %f | rho = %f | qmodel_change = %f',\n            damping_, rho_, qmodel_change_)\n\n      tf.logging.info('----')\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.app.run(main)\n"
  },
  {
    "path": "kfac/examples/autoencoder_mnist_tpu_estimator.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Implementation of Deep AutoEncoder from Martens & Grosse (2015).\n\nThis script demonstrates training on TPUs with TPU Estimator using the KFAC\noptimizer, updating the damping parameter according to the\nLevenberg-Marquardt rule, and using the quadratic model method for adapting\nthe learning rate and momentum parameters.\n\nSee third_party/tensorflow_kfac/google/examples/ae_tpu_xm_launcher.py\nfor an example Borg launch script.  
If you can't access this launch script,\nsome important things to know about running K-FAC on TPUs (at least for this\nexample) are that you must use higher-precision matrix multiplications.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nfrom absl import flags\nimport kfac\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.contrib import tpu as contrib_tpu\n\nfrom kfac.examples import autoencoder_mnist\nfrom kfac.examples import mnist\n\n\nflags.DEFINE_integer('save_checkpoints_steps', 500,\n                     'Number of iterations between model checkpoints.')\n\nflags.DEFINE_integer('iterations_per_loop', 100,\n                     'Number of iterations in a TPU training loop.')\n\nflags.DEFINE_string('model_dir', '', 'Model dir.')\n\nflags.DEFINE_string('master', None,\n                    'GRPC URL of the master '\n                    '(e.g. grpc://ip.address.of.tpu:8470).')\n\n\nFLAGS = flags.FLAGS\n\n\ndef make_train_op(minibatch,\n                  batch_loss,\n                  layer_collection,\n                  loss_fn):\n  \"\"\"Constructs optimizer and train op.\n\n  Args:\n    minibatch: Tuple[Tensor, Tensor] representing the current batch of input\n      images and labels.\n    batch_loss: Tensor of shape (), Loss with respect to minibatch to be\n      minimzed.\n    layer_collection: LayerCollection object. Registry for model parameters.\n      Required when using a K-FAC optimizer.\n    loss_fn: A function that when called constructs the graph to compute the\n      model loss on the current minibatch.  Returns a Tensor of the loss scalar.\n\n  Returns:\n    train_op: Op that can be used to update model parameters.\n    optimizer: The KFAC optimizer used to produce train_op.\n\n  Raises:\n    ValueError: If layer_collection is None when K-FAC is selected as an\n      optimization method.\n  \"\"\"\n  # Do not use CrossShardOptimizer with K-FAC. 
K-FAC now handles its own\n  # cross-replica syncronization automatically!\n\n  return autoencoder_mnist.make_train_op(\n      minibatch=minibatch,\n      batch_size=minibatch[0].get_shape().as_list()[0],\n      batch_loss=batch_loss,\n      layer_collection=layer_collection,\n      loss_fn=loss_fn,\n      prev_train_batch=None,\n      placement_strategy='replica_round_robin',\n      )\n\n\ndef compute_squared_error(logits, targets):\n  \"\"\"Compute mean squared error.\"\"\"\n  return tf.reduce_sum(\n      tf.reduce_mean(tf.square(targets - tf.nn.sigmoid(logits)), axis=0))\n\n\ndef compute_loss(logits, labels):\n  \"\"\"Compute loss value.\"\"\"\n  graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)\n  total_regularization_loss = tf.add_n(graph_regularizers)\n  loss_matrix = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,\n                                                        labels=labels)\n  loss = tf.reduce_sum(tf.reduce_mean(loss_matrix, axis=0))\n  regularized_loss = loss + total_regularization_loss\n  return regularized_loss\n\n\ndef mnist_input_fn(params):\n  dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)\n\n  # Shuffle before repeat is correct unless you want repeat cases in the\n  # same batch.\n  dataset = (\n      dataset.shuffle(num_examples).repeat().batch(\n          params['batch_size'],\n          drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE))\n  return dataset\n\n\ndef print_tensors(**tensors):\n  \"\"\"Host call function to print Tensors from the TPU during training.\"\"\"\n  print_op = tf.no_op()\n  for name in sorted(tensors):\n    with tf.control_dependencies([print_op]):\n      tensor = tensors[name]\n      if name in ['error', 'loss']:\n        tensor = tf.reduce_mean(tensor)\n      print_op = tf.Print(tensor, [tensor], message=name + '=')\n  with tf.control_dependencies([print_op]):\n    return tf.Print(0., [0.], message='------')\n\n\ndef _model_fn(features, labels, 
mode, params):\n  \"\"\"Estimator model_fn for an autoencoder with adaptive damping.\"\"\"\n  del params\n  layer_collection = kfac.LayerCollection()\n  training_model_fn = autoencoder_mnist.AutoEncoder(784)\n\n  def loss_fn(minibatch, logits=None):\n    \"\"\"Compute the model loss given a batch of inputs.\n\n    Args:\n      minibatch: `Tuple[Tensor, Tensor]` for the current batch of input images\n        and labels.\n      logits: `Tensor` for the current batch of logits. If None then reuses the\n        AutoEncoder to compute them.\n\n    Returns:\n      `Tensor` for the batch loss.\n    \"\"\"\n    features, labels = minibatch\n    del labels\n    if logits is None:\n      # Note we do not need to do anything like\n      # `with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):`\n      # here because Sonnet takes care of variable reuse for us as long as we\n      # call the same `training_model_fn` module.  Otherwise we would need to\n      # use variable reusing here.\n      logits = training_model_fn(features)\n    batch_loss = compute_loss(logits=logits, labels=features)\n    return batch_loss\n\n  logits = training_model_fn(features)\n  pre_update_batch_loss = loss_fn((features, labels), logits=logits)\n  pre_update_batch_error = compute_squared_error(logits, features)\n\n  if mode == tf.estimator.ModeKeys.TRAIN:\n    # Make sure never to confuse this with register_softmax_cross_entropy_loss!\n    layer_collection.register_sigmoid_cross_entropy_loss(logits,\n                                                         seed=FLAGS.seed + 1)\n    layer_collection.auto_register_layers()\n\n    global_step = tf.train.get_or_create_global_step()\n    train_op, kfac_optimizer = make_train_op(\n        (features, labels),\n        pre_update_batch_loss,\n        layer_collection,\n        loss_fn)\n\n    tensors_to_print = {\n        'learning_rate': tf.expand_dims(kfac_optimizer.learning_rate, 0),\n        'momentum': 
tf.expand_dims(kfac_optimizer.momentum, 0),\n        'damping': tf.expand_dims(kfac_optimizer.damping, 0),\n        'global_step': tf.expand_dims(global_step, 0),\n        'loss': tf.expand_dims(pre_update_batch_loss, 0),\n        'error': tf.expand_dims(pre_update_batch_error, 0),\n    }\n    if FLAGS.adapt_damping:\n      tensors_to_print['qmodel_change'] = tf.expand_dims(\n          kfac_optimizer.qmodel_change, 0)\n      tensors_to_print['rho'] = tf.expand_dims(kfac_optimizer.rho, 0)\n\n    return contrib_tpu.TPUEstimatorSpec(\n        mode=mode,\n        loss=pre_update_batch_loss,\n        train_op=train_op,\n        host_call=(print_tensors, tensors_to_print),\n        eval_metrics=None)\n\n  else:  # mode == tf.estimator.ModeKeys.{EVAL, PREDICT}:\n    return contrib_tpu.TPUEstimatorSpec(\n        mode=mode,\n        loss=pre_update_batch_loss,\n        eval_metrics=None)\n\n\ndef make_tpu_run_config(master, seed, model_dir, iterations_per_loop,\n                        save_checkpoints_steps):\n  return contrib_tpu.RunConfig(\n      master=master,\n      evaluation_master=master,\n      model_dir=model_dir,\n      save_checkpoints_steps=save_checkpoints_steps,\n      cluster=None,\n      tf_random_seed=seed,\n      tpu_config=contrib_tpu.TPUConfig(iterations_per_loop=iterations_per_loop))\n\n\ndef main(argv):\n\n  if FLAGS.use_control_flow_v2:\n    tf.enable_control_flow_v2()\n\n  del argv  # Unused.\n  tf.set_random_seed(FLAGS.seed)\n  # Invert using cholesky decomposition + triangular solve.  
This is the only\n  # code path for matrix inversion supported on TPU right now.\n  kfac.utils.set_global_constants(posdef_inv_method='cholesky')\n  kfac.fisher_factors.set_global_constants(\n      eigenvalue_decomposition_threshold=10000)\n\n  config = make_tpu_run_config(\n      FLAGS.master, FLAGS.seed, FLAGS.model_dir, FLAGS.iterations_per_loop,\n      FLAGS.save_checkpoints_steps)\n\n  estimator = contrib_tpu.TPUEstimator(\n      use_tpu=True,\n      model_fn=_model_fn,\n      config=config,\n      train_batch_size=FLAGS.batch_size,\n      eval_batch_size=1024)\n\n  estimator.train(\n      input_fn=mnist_input_fn,\n      max_steps=FLAGS.train_steps,\n      hooks=[])\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.app.run(main)\n"
  },
  {
    "path": "kfac/examples/autoencoder_mnist_tpu_strategy.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Implementation of Deep AutoEncoder from Martens & Grosse (2015).\n\nThis script demonstrates training on TPUs with TPUStrategy using the KFAC\noptimizer, updating the damping parameter according to the\nLevenberg-Marquardt rule, and using the quadratic model method for adapting\nthe learning rate and momentum parameters.\n\nSee third_party/tensorflow_kfac/google/examples/ae_tpu_xm_launcher.py\nfor an example Borg launch script.  
If you can't access this launch script,\nsome important things to know about running K-FAC on TPUs (at least for this\nexample) are that you must use high-precision matrix multiplications.\niterations_per_loop is not relevant when using TPU Strategy, but you must set\nit to 1 when using TPU Estimator.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nfrom absl import flags\nimport kfac\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.examples import autoencoder_mnist\nfrom kfac.examples import mnist\n\n\n# TODO(znado): figure out the bug with this and update_damping_immediately=True.\n# TODO(znado): Add checkpointing code to the training loop.\nflags.DEFINE_integer('save_checkpoints_steps', 500,\n                     'Number of iterations between model checkpoints.')\nflags.DEFINE_string('model_dir', '', 'Model dir.')\n\n# iterations_per_loop is not used with TPU Strategy. We keep the flag so the\n# Estimator launching script can be used.\nflags.DEFINE_integer('iterations_per_loop', 1,\n                     'Number of iterations in a TPU training loop.')\n\nflags.DEFINE_string('master', None,\n                    'GRPC URL of the master '\n                    '(e.g. grpc://ip.address.of.tpu:8470).')\n\n\nFLAGS = flags.FLAGS\n\n\ndef make_train_op(minibatch,\n                  batch_loss,\n                  layer_collection,\n                  loss_fn):\n  \"\"\"Constructs optimizer and train op.\n\n  Args:\n    minibatch: Tuple[Tensor, Tensor] representing the current batch of input\n      images and labels.\n    batch_loss: Tensor of shape (), Loss with respect to minibatch to be\n      minimzed.\n    layer_collection: LayerCollection object. Registry for model parameters.\n      Required when using a K-FAC optimizer.\n    loss_fn: A function that when called constructs the graph to compute the\n      model loss on the current minibatch.  
Returns a Tensor of the loss scalar.\n\n  Returns:\n    train_op: Op that can be used to update model parameters.\n    optimizer: The KFAC optimizer used to produce train_op.\n\n  Raises:\n    ValueError: If layer_collection is None when K-FAC is selected as an\n      optimization method.\n  \"\"\"\n  # Do not use CrossShardOptimizer with K-FAC. K-FAC now handles its own\n  # cross-replica syncronization automatically!\n\n  return autoencoder_mnist.make_train_op(\n      minibatch=minibatch,\n      batch_size=minibatch[0].get_shape().as_list()[0],\n      batch_loss=batch_loss,\n      layer_collection=layer_collection,\n      loss_fn=loss_fn,\n      prev_train_batch=None,\n      placement_strategy='replica_round_robin',\n      )\n\n\ndef compute_squared_error(logits, targets):\n  \"\"\"Compute mean squared error.\"\"\"\n  return tf.reduce_sum(\n      tf.reduce_mean(tf.square(targets - tf.nn.sigmoid(logits)), axis=0))\n\n\ndef compute_loss(logits, labels, model):\n  \"\"\"Compute loss value.\"\"\"\n  loss_matrix = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,\n                                                        labels=labels)\n  regularization_loss = tf.reduce_sum(model.losses)\n  crossentropy_loss = tf.reduce_sum(tf.reduce_mean(loss_matrix, axis=0))\n  return crossentropy_loss + regularization_loss\n\n\ndef mnist_input_fn(batch_size):\n  dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)\n\n  # Shuffle before repeat is correct unless you want repeat cases in the\n  # same batch.\n  dataset = (dataset.shuffle(num_examples)\n             .repeat()\n             .batch(batch_size, drop_remainder=True)\n             .prefetch(tf.data.experimental.AUTOTUNE))\n  return dataset\n\n\ndef _train_step(batch):\n  \"\"\"Estimator model_fn for an autoencoder with adaptive damping.\"\"\"\n  features, labels = batch\n  model = autoencoder_mnist.get_keras_autoencoder(tensor=features)\n\n  def loss_fn(minibatch, logits=None):\n    \"\"\"Compute the 
model loss given a batch of inputs.\n\n    Args:\n      minibatch: `Tuple[Tensor, Tensor]` for the current batch of input images\n        and labels.\n      logits: `Tensor` for the current batch of logits. If None then reuses the\n        AutoEncoder to compute them.\n\n    Returns:\n      `Tensor` for the batch loss.\n    \"\"\"\n    features, labels = minibatch\n    del labels\n    if logits is None:\n      logits = model(features)\n    batch_loss = compute_loss(logits=logits, labels=features, model=model)\n    return batch_loss\n\n  logits = model.output\n  pre_update_batch_loss = loss_fn((features, labels), logits)\n  pre_update_batch_error = compute_squared_error(logits, features)\n\n  # binary_crossentropy corresponds to sigmoid_crossentropy.\n  layer_collection = kfac.keras.utils.get_layer_collection(\n      model, 'binary_crossentropy', seed=FLAGS.seed + 1)\n\n  global_step = tf.train.get_or_create_global_step()\n  train_op, kfac_optimizer = make_train_op(\n      (features, labels),\n      pre_update_batch_loss,\n      layer_collection,\n      loss_fn)\n  tensors_to_print = {\n      'learning_rate': kfac_optimizer.learning_rate,\n      'momentum': kfac_optimizer.momentum,\n      'damping': kfac_optimizer.damping,\n      'global_step': global_step,\n      'loss': pre_update_batch_loss,\n      'error': pre_update_batch_error,\n  }\n  if FLAGS.adapt_damping:\n    tensors_to_print['qmodel_change'] = kfac_optimizer.qmodel_change\n    tensors_to_print['rho'] = kfac_optimizer.rho\n\n  with tf.control_dependencies([train_op]):\n    return {k: tf.identity(v) for k, v in tensors_to_print.items()}\n\n\ndef train():\n  \"\"\"Trains the Autoencoder using TPU Strategy.\"\"\"\n  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(\n      tpu=FLAGS.master)\n  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)\n  tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)\n\n  with tpu_strategy.scope():\n    data = 
mnist_input_fn(batch_size=FLAGS.batch_size)\n    train_iterator = tpu_strategy.make_dataset_iterator(data)\n    tensor_dict = tpu_strategy.experimental_run(_train_step, train_iterator)\n    for k, v in tensor_dict.items():\n      if k in ('loss', 'error'):   # Losses are NOT scaled for num replicas.\n        tensor_dict[k] = tpu_strategy.reduce(tf.distribute.ReduceOp.MEAN, v)\n      else:  # Other tensors (hyperparameters) are identical across replicas.\n        # experimental_local_results gives you a tuple of per-replica values.\n        tensor_dict[k] = tpu_strategy.experimental_local_results(v)\n\n  config = tf.ConfigProto(allow_soft_placement=True)\n  with tf.Session(cluster_resolver.master(), config=config) as session:\n    session.run(tf.global_variables_initializer())\n    session.run(train_iterator.initializer)\n    print('Starting training.')\n    for step in range(FLAGS.train_steps):\n      values_dict = session.run(tensor_dict)\n      print('Training Step: {}'.format(step))\n      for k, v in values_dict.items():\n        print('{}: {}'.format(k, v))\n    print('Done training.')\n\n\ndef main(argv):\n  del argv  # Unused.\n  tf.set_random_seed(FLAGS.seed)\n  # Invert using cholesky decomposition + triangular solve.  This is the only\n  # code path for matrix inversion supported on TPU right now.\n  kfac.utils.set_global_constants(posdef_inv_method='cholesky')\n  kfac.fisher_factors.set_global_constants(\n      eigenvalue_decomposition_threshold=10000)\n\n  train()\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.app.run(main)\n"
  },
  {
    "path": "kfac/examples/classifier_mnist.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"A simple MNIST classifier example.\n\nThis script demonstrates training using KFAC optimizer, updating the damping\nparameter according to the Levenberg-Marquardt rule, and using the quadratic\nmodel method for adapting the learning rate and momentum parameters.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\n# Dependency imports\nfrom absl import flags\nimport kfac\nimport sonnet as snt\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.examples import mnist\nfrom kfac.python.ops.kfac_utils import data_reader\nfrom kfac.python.ops.kfac_utils import data_reader_alt\n\n\n# Model parameters\n_NONLINEARITY = tf.nn.relu  # can also be tf.nn.tanh\n_POOL = 'MAX'  # can also be 'AVG'\n\n\nflags.DEFINE_integer('train_steps', 10000, 'Number of training steps.')\n\nflags.DEFINE_integer('inverse_update_period', 5,\n                     '# of steps between computing inverse of Fisher factor '\n                     'matrices.')\n\nflags.DEFINE_integer('cov_update_period', 1,\n                     '# of steps between computing covaraiance matrices.')\n\nflags.DEFINE_integer('damping_adaptation_interval', 5,\n                     '# of steps between updating the damping 
parameter.')\n\nflags.DEFINE_integer('num_burnin_steps', 5, 'Number of steps the at the '\n                     'start of training where the optimizer will only perform '\n                     'cov updates. Will not work on CrossShardOptimizer. See '\n                     'PeriodicInvCovUpdateKfacOpt for details.')\n\nflags.DEFINE_integer('seed', 12345, 'Random seed')\n\nflags.DEFINE_float('learning_rate', 3e-3,\n                   'Learning rate to use when lrmu_adaptation=\"off\".')\n\nflags.DEFINE_float('momentum', 0.9,\n                   'Momentum decay value to use when '\n                   'lrmu_adaptation=\"off\" or \"only_lr\".')\n\nflags.DEFINE_float('damping', 1e-2, 'The fixed damping value to use. This is '\n                   'ignored if adapt_damping is True.')\n\nflags.DEFINE_float('l2_reg', 1e-5,\n                   'L2 regularization applied to weight matrices.')\n\nflags.DEFINE_boolean('update_damping_immediately', True, 'Adapt the damping '\n                     'immediately after the parameter update (i.e. in the same '\n                     'sess.run() call).  Only safe if everything is a resource '\n                     'variable.')\n\nflags.DEFINE_boolean('use_batch_size_schedule', True,\n                     'If True then we use the growing mini-batch schedule from '\n                     'the original K-FAC paper.')\nflags.DEFINE_integer('batch_size', 1024,\n                     'The size of the mini-batches to use if not using the '\n                     'schedule.')\n\nflags.DEFINE_string('lrmu_adaptation', 'on',\n                    'If set to \"on\" then we use the quadratic model '\n                    'based learning-rate and momentum adaptation method from '\n                    'the original paper. Note that this only works well in '\n                    'practice when use_batch_size_schedule=True. 
Can also '\n                    'be set to \"off\" and \"only_lr\", which turns '\n                    'it off, or uses a version where the momentum parameter '\n                    'is fixed (resp.).')\n\nflags.DEFINE_boolean('use_alt_data_reader', True,\n                     'If True we use the alternative data reader for MNIST '\n                     'that is faster for small datasets.')\n\nflags.DEFINE_string('device', '/gpu:0',\n                    'The device to run the major ops on.')\n\nflags.DEFINE_boolean('adapt_damping', True,\n                     'If True we use the LM rule for damping adaptation as '\n                     'described in the original K-FAC paper.')\n\n# When using damping adaptation it is advisable to start with a high\n# value. This value is probably far too high to use for most neural nets\n# if you aren't using damping adaptation. (Although it always depends on\n# the scale of the loss.)\nflags.DEFINE_float('initial_damping', 0.1,\n                   'The initial damping value to use when adapt_damping is '\n                   'True.')\n\nflags.DEFINE_string('optimizer', 'kfac',\n                    'The optimizer to use. Can be kfac or adam. If adam is '\n                    'used the various kfac hyperparameter map roughly on to '\n                    'their Adam equivalents.')\n\nflags.DEFINE_float('polyak_decay', 0.995, 'Rate of decay for Polyak averaging.')\n\nflags.DEFINE_integer('eval_every', 50,\n                     'Interval to print total training loss.')\n\nflags.DEFINE_boolean('use_sua_approx', False,\n                     'If True we use the SUA approximation for conv layers.')\n\nflags.DEFINE_string('dtype', 'float32',\n                    'The DTYPE to use for all computations. 
Can by float32 '\n                    'or float64.')\n\n\nflags.DEFINE_boolean('use_custom_patches_op', False,\n                     'If True we use the custom XLA implementation of the op '\n                     'which computes the second moment of patch vectors.')\n\n\nFLAGS = flags.FLAGS\n\n\nclass Model(snt.AbstractModule):\n  \"\"\"CNN model for MNIST data.\"\"\"\n\n  def _build(self, inputs):\n\n    if FLAGS.l2_reg:\n      regularizers = {'w': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),\n                      'b': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),}\n    else:\n      regularizers = None\n\n    reshape = snt.BatchReshape([28, 28, 1])\n\n    conv = snt.Conv2D(2, 5, padding=snt.SAME, regularizers=regularizers)\n    act = _NONLINEARITY(conv(reshape(inputs)))\n\n    pool = tf.nn.pool(act, window_shape=(2, 2), pooling_type=_POOL,\n                      padding=snt.SAME, strides=(2, 2))\n\n    conv = snt.Conv2D(4, 5, padding=snt.SAME, regularizers=regularizers)\n    act = _NONLINEARITY(conv(pool))\n\n    pool = tf.nn.pool(act, window_shape=(2, 2), pooling_type=_POOL,\n                      padding=snt.SAME, strides=(2, 2))\n\n    flatten = snt.BatchFlatten()(pool)\n\n    linear = snt.Linear(32, regularizers=regularizers)(flatten)\n\n    return snt.Linear(10, regularizers=regularizers)(linear)\n\n\ndef make_train_op(minibatch,\n                  batch_size,\n                  batch_loss,\n                  layer_collection,\n                  loss_fn,\n                  prev_train_batch=None,\n                  placement_strategy=None,\n                  print_logs=False,\n                  tf_replicator=None):\n  \"\"\"Constructs optimizer and train op.\n\n  Args:\n    minibatch: A list/tuple of Tensors (typically representing the current\n      mini-batch of input images and labels).\n    batch_size: Tensor of shape (). Size of the training mini-batch.\n    batch_loss: Tensor of shape (). Mini-batch loss tensor.\n    layer_collection: LayerCollection object. 
Registry for model parameters.\n      Required when using a K-FAC optimizer.\n    loss_fn: Function which takes as input a mini-batch and returns the loss.\n    prev_train_batch: `Tensor` of the previous training batch, can be accessed\n      from the data_reader.CachedReader cached_batch property. (Default: None)\n    placement_strategy: `str`, the placement_strategy argument for\n      `KfacOptimizer`. (Default: None)\n    print_logs: `Bool`. If True we print logs using K-FAC's built-in\n      tf.print-based logs printer. (Default: False)\n    tf_replicator: A Replicator object or None. If not None, K-FAC will set\n        itself up to work inside of the provided TF-Replicator object.\n        (Default: None)\n\n  Returns:\n    train_op: Op that can be used to update model parameters.\n    optimizer: Optimizer used to produce train_op.\n\n  Raises:\n    ValueError: If layer_collection is None when K-FAC is selected as an\n      optimization method.\n  \"\"\"\n  global_step = tf.train.get_or_create_global_step()\n\n  if FLAGS.optimizer == 'kfac':\n    if FLAGS.lrmu_adaptation == 'on':\n      learning_rate = None\n      momentum = None\n      momentum_type = 'qmodel'\n    elif FLAGS.lrmu_adaptation == 'only_lr':\n      learning_rate = None\n      momentum = FLAGS.momentum\n      momentum_type = 'qmodel_fixedmu'\n    elif FLAGS.lrmu_adaptation == 'off':\n      learning_rate = FLAGS.learning_rate\n      momentum = FLAGS.momentum\n      # momentum_type = 'regular'\n      momentum_type = 'adam'\n\n    if FLAGS.adapt_damping:\n      damping = FLAGS.initial_damping\n    else:\n      damping = FLAGS.damping\n\n    optimizer = kfac.PeriodicInvCovUpdateKfacOpt(\n        invert_every=FLAGS.inverse_update_period,\n        cov_update_every=FLAGS.cov_update_period,\n        learning_rate=learning_rate,\n        damping=damping,\n        cov_ema_decay=0.95,\n        momentum=momentum,\n        momentum_type=momentum_type,\n        layer_collection=layer_collection,\n        
batch_size=batch_size,\n        num_burnin_steps=FLAGS.num_burnin_steps,\n        adapt_damping=FLAGS.adapt_damping,\n        # Note that many of the arguments below don't do anything when\n        # adapt_damping=False.\n        update_damping_immediately=FLAGS.update_damping_immediately,\n        is_chief=True,\n        prev_train_batch=prev_train_batch,\n        loss=batch_loss,\n        loss_fn=loss_fn,\n        damping_adaptation_decay=0.9,\n        damping_adaptation_interval=FLAGS.damping_adaptation_interval,\n        min_damping=1e-6,\n        l2_reg=FLAGS.l2_reg,\n        train_batch=minibatch,\n        placement_strategy=placement_strategy,\n        print_logs=print_logs,\n        tf_replicator=tf_replicator,\n        dtype=FLAGS.dtype,\n        )\n\n  elif FLAGS.optimizer == 'adam':\n    optimizer = tf.train.AdamOptimizer(\n        learning_rate=FLAGS.learning_rate,\n        beta1=FLAGS.momentum,\n        epsilon=FLAGS.damping,\n        beta2=0.99)\n\n  return optimizer.minimize(batch_loss, global_step=global_step), optimizer\n\n\ndef compute_loss(logits=None,\n                 labels=None,\n                 return_error=False,\n                 use_regularizer=True):\n  \"\"\"Compute loss value.\"\"\"\n  loss_matrix = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,\n                                                               labels=labels)\n  total_loss = tf.reduce_mean(loss_matrix, axis=0)\n\n  if use_regularizer:\n    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)\n    total_regularization_loss = tf.add_n(graph_regularizers)\n\n    total_loss += tf.cast(total_regularization_loss, dtype=total_loss.dtype)\n\n  if return_error:\n    error = 1.0 - tf.reduce_mean(tf.cast(\n        tf.equal(labels, tf.argmax(logits, axis=1, output_type=tf.int32)),\n        tf.float32))\n    return total_loss, error\n\n  return total_loss\n\n\ndef load_mnist():\n  \"\"\"Creates MNIST dataset and wraps it inside cached data 
reader.\n\n  Returns:\n    cached_reader: `data_reader.CachedReader` instance which wraps MNIST\n      dataset.\n    num_examples: int. The number of training examples.\n  \"\"\"\n  # Wrap the data set into cached_reader which provides variable sized training\n  # and caches the read train batch.\n\n  if not FLAGS.use_alt_data_reader:\n    # Version 1 using data_reader.py (slow!)\n    dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)\n    if FLAGS.use_batch_size_schedule:\n      max_batch_size = num_examples\n    else:\n      max_batch_size = FLAGS.batch_size\n\n    # Shuffle before repeat is correct unless you want repeat cases in the\n    # same batch.\n    dataset = (dataset.shuffle(num_examples).repeat()\n               .batch(max_batch_size).prefetch(5))\n    dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()\n\n    # This version of CachedDataReader requires the dataset to be shuffled\n    return data_reader.CachedDataReader(dataset, max_batch_size), num_examples\n\n  else:\n    # Version 2 using data_reader_alt.py (faster)\n    images, labels, num_examples = mnist.load_mnist_as_tensors(\n        flatten_images=True, dtype=tf.dtypes.as_dtype(FLAGS.dtype))\n    dataset = (images, labels)\n\n    # This version of CachedDataReader requires the dataset to NOT be shuffled\n    return data_reader_alt.CachedDataReader(dataset, num_examples), num_examples\n\n\ndef _get_batch_size_schedule(num_examples):\n  \"\"\"Returns training batch size schedule.\"\"\"\n  minibatch_maxsize_targetiter = 100  # We use a smaller target iter here than\n                                      # in the autoencoder example.\n  minibatch_maxsize = num_examples\n  minibatch_startsize = 1000\n\n  div = (float(minibatch_maxsize_targetiter-1)\n         / math.log(float(minibatch_maxsize)/minibatch_startsize, 2))\n  return [\n      min(int(2.**(float(k)/div) * minibatch_startsize), minibatch_maxsize)\n      for k in range(minibatch_maxsize_targetiter)\n 
 ]\n\n\ndef group_assign(dest, source):\n  return tf.group(*(d.assign(s) for d, s in zip(dest, source)))\n\n\ndef make_eval_ops(train_vars, ema):\n  # This does evaluation with and without Polyak averaging.\n\n  images, labels, _ = mnist.load_mnist_as_tensors(\n      flatten_images=True, dtype=tf.dtypes.as_dtype(FLAGS.dtype))\n\n  eval_model = Model()\n  eval_model(images)  # We need this dummy call because the variables won't\n                      # exist otherwise.\n  eval_vars = eval_model.variables\n\n  update_eval_model = group_assign(eval_vars, train_vars)\n\n  with tf.control_dependencies([update_eval_model]):\n    logits = eval_model(images)\n    eval_loss, eval_error = compute_loss(\n        logits=logits, labels=labels, return_error=True)\n\n    with tf.control_dependencies([eval_loss, eval_error]):\n      update_eval_model_avg = group_assign(\n          eval_vars, (ema.average(t) for t in train_vars))\n\n      with tf.control_dependencies([update_eval_model_avg]):\n        logits = eval_model(images)\n        eval_loss_avg, eval_error_avg = compute_loss(\n            logits=logits, labels=labels, return_error=True)\n\n  return eval_loss, eval_error, eval_loss_avg, eval_error_avg\n\n\ndef construct_train_quants():\n  with tf.device(FLAGS.device):\n    # Load dataset.\n    cached_reader, num_examples = load_mnist()\n    batch_size_schedule = _get_batch_size_schedule(num_examples)\n    batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')\n\n    minibatch = cached_reader(batch_size)\n    features, _ = minibatch\n    training_model = Model()\n    layer_collection = kfac.LayerCollection()\n\n    if FLAGS.use_sua_approx:\n      layer_collection.set_default_conv2d_approximation('kron_sua')\n\n    ema = tf.train.ExponentialMovingAverage(FLAGS.polyak_decay,\n                                            zero_debias=True)\n\n    def loss_fn(minibatch, logits=None, return_error=False):\n      features, labels = minibatch\n      if logits is 
None:\n        logits = training_model(features)\n      return compute_loss(\n          logits=logits,\n          labels=labels,\n          return_error=return_error)\n\n    logits = training_model(features)\n\n    (batch_loss, batch_error) = loss_fn(\n        minibatch, logits=logits, return_error=True)\n\n    # Make sure never to confuse this with register_sigmoid_cross_entropy_loss!\n    layer_collection.register_softmax_cross_entropy_loss(logits,\n                                                         seed=FLAGS.seed + 1)\n    layer_collection.auto_register_layers()\n\n    train_vars = training_model.variables\n\n    # Make training op:\n    train_op, opt = make_train_op(\n        minibatch,\n        batch_size,\n        batch_loss,\n        layer_collection,\n        loss_fn=loss_fn,\n        prev_train_batch=cached_reader.cached_batch)\n\n    with tf.control_dependencies([train_op]):\n      train_op = ema.apply(train_vars)\n\n    # We clear out the regularizers collection when creating our evaluation\n    # graph (which uses different variables). 
It is important that we do this\n    # only after the train op is constructed, since the minimize() will call\n    # into the loss function (which includes the regularizer):\n    tf.get_default_graph().clear_collection(tf.GraphKeys.REGULARIZATION_LOSSES)\n\n    # These aren't run in the same sess.run call as train_op:\n    (eval_loss, eval_error,\n     eval_loss_avg, eval_error_avg) = make_eval_ops(train_vars, ema)\n\n  return (train_op, opt, batch_loss, batch_error, batch_size_schedule,\n          batch_size, eval_loss, eval_error, eval_loss_avg, eval_error_avg)\n\n\ndef main(_):\n\n  # If using update_damping_immediately resource variables must be enabled.\n  if FLAGS.update_damping_immediately:\n    tf.enable_resource_variables()\n\n  if not FLAGS.use_sua_approx:\n    if FLAGS.use_custom_patches_op:\n      kfac.fisher_factors.set_global_constants(\n          use_patches_second_moment_op=True\n          )\n    else:\n      # Temporary measure to save memory with giant batches:\n      kfac.fisher_factors.set_global_constants(\n          sub_sample_inputs=True,\n          inputs_to_extract_patches_factor=0.2)\n\n  tf.set_random_seed(FLAGS.seed)\n  (train_op, opt, batch_loss, batch_error, batch_size_schedule, batch_size,\n   eval_loss, eval_error,\n   eval_loss_avg, eval_error_avg) = construct_train_quants()\n\n  global_step = tf.train.get_or_create_global_step()\n\n  if FLAGS.optimizer == 'kfac':\n    # We need to put the control depenency on train_op here so that we are\n    # guaranteed to get the up-to-date values of these various quantities.\n    # Otherwise there is a race condition and we might get the old values,\n    # nondeterministically. 
Another solution would be to get these values in\n    # a separate sess.run call, but this can sometimes cause problems with\n    # training frameworks that use hooks (see the comments below).\n    with tf.control_dependencies([train_op]):\n      learning_rate = opt.learning_rate\n      momentum = opt.momentum\n      damping = opt.damping\n      rho = opt.rho\n      qmodel_change = opt.qmodel_change\n\n  # Without setting allow_soft_placement=True there will be problems when\n  # the optimizer tries to place certain ops like \"mod\" on the GPU (which isn't\n  # supported).\n  config = tf.ConfigProto(allow_soft_placement=True)\n\n  # Train model.\n\n  # It's good practice to put everything into a single sess.run call. The\n  # reason is that certain \"training frameworks\" like to run hooks at each\n  # sess.run call, and there is an implicit expectation there will only\n  # be one sess.run call every \"iteration\" of the \"optimizer\". For example,\n  # a framework might try to print the loss at each sess.run call, causing\n  # the mini-batch to be advanced, thus completely breaking the \"cached\n  # batch\" mechanism that the damping adaptation method may rely on. (Plus\n  # there will also be the extra cost of having to reevaluate the loss\n  # twice.)  
That being said we don't completely do that here because it's\n  # inconvenient.\n  with tf.train.MonitoredTrainingSession(save_checkpoint_secs=30,\n                                         config=config) as sess:\n    for _ in range(FLAGS.train_steps):\n      i = sess.run(global_step)\n\n      if FLAGS.use_batch_size_schedule:\n        batch_size_ = batch_size_schedule[min(i, len(batch_size_schedule) - 1)]\n      else:\n        batch_size_ = FLAGS.batch_size\n\n      if FLAGS.optimizer == 'kfac':\n        (_, batch_loss_, batch_error_, learning_rate_, momentum_, damping_,\n         rho_, qmodel_change_) = sess.run([train_op, batch_loss, batch_error,\n                                           learning_rate, momentum, damping,\n                                           rho, qmodel_change],\n                                          feed_dict={batch_size: batch_size_})\n      else:\n        _, batch_loss_, batch_error_ = sess.run(\n            [train_op, batch_loss, batch_error],\n            feed_dict={batch_size: batch_size_})\n\n      # Print training stats.\n      tf.logging.info(\n          'iteration: %d', i)\n      tf.logging.info(\n          'mini-batch size: %d | mini-batch loss = %f | mini-batch error = %f ',\n          batch_size_, batch_loss_, batch_error_)\n\n      if FLAGS.optimizer == 'kfac':\n        tf.logging.info(\n            'learning_rate = %f | momentum = %f',\n            learning_rate_, momentum_)\n        tf.logging.info(\n            'damping = %f | rho = %f | qmodel_change = %f',\n            damping_, rho_, qmodel_change_)\n\n      # \"Eval\" here means just compute stuff on the full training set.\n      if (i+1) % FLAGS.eval_every == 0:\n        eval_loss_, eval_error_, eval_loss_avg_, eval_error_avg_ = sess.run(\n            [eval_loss, eval_error, eval_loss_avg, eval_error_avg])\n        tf.logging.info('-----------------------------------------------------')\n        tf.logging.info('eval_loss = %f | eval_error = %f',\n              
          eval_loss_, eval_error_)\n        tf.logging.info('eval_loss_avg = %f | eval_error_avg = %f',\n                        eval_loss_avg_, eval_error_avg_)\n        tf.logging.info('-----------------------------------------------------')\n      else:\n        tf.logging.info('----')\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.app.run(main)\n"
  },
  {
    "path": "kfac/examples/classifier_mnist_tpu_estimator.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"A simple MNIST classifier example.\n\nThis script demonstrates training on TPUs with TPU Estimator using the KFAC\noptimizer, updating the damping parameter according to the\nLevenberg-Marquardt rule, and using the quadratic model method for adapting\nthe learning rate and momentum parameters.\n\nSee third_party/tensorflow_kfac/google/examples/classifier_tpu_xm_launcher.py\nfor an example Borg launch script.  
If you can't access this launch script,\nsome important things to know about running K-FAC on TPUs (at least for this\nexample) are that you must use higher-precision matrix multiplications.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nfrom absl import flags\nimport kfac\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.contrib import tpu as contrib_tpu\n\nfrom kfac.examples import classifier_mnist\nfrom kfac.examples import mnist\n\n\nflags.DEFINE_integer('save_checkpoints_steps', 500,\n                     'Number of iterations between model checkpoints.')\n\nflags.DEFINE_integer('iterations_per_loop', 100,\n                     'Number of iterations in a TPU training loop.')\n\nflags.DEFINE_string('model_dir', '', 'Model dir.')\n\nflags.DEFINE_string('master', None,\n                    'GRPC URL of the master '\n                    '(e.g. grpc://ip.address.of.tpu:8470).')\n\n\nFLAGS = flags.FLAGS\n\n\ndef make_train_op(minibatch,\n                  batch_loss,\n                  layer_collection,\n                  loss_fn):\n  \"\"\"Constructs optimizer and train op.\n\n  Args:\n    minibatch: Tuple[Tensor, Tensor] representing the current batch of input\n      images and labels.\n    batch_loss: Tensor of shape (), Loss with respect to minibatch to be\n      minimized.\n    layer_collection: LayerCollection object. Registry for model parameters.\n      Required when using a K-FAC optimizer.\n    loss_fn: A function that when called constructs the graph to compute the\n      model loss on the current minibatch.  Returns a Tensor of the loss scalar.\n\n  Returns:\n    train_op: Op that can be used to update model parameters.\n    optimizer: The KFAC optimizer used to produce train_op.\n\n  Raises:\n    ValueError: If layer_collection is None when K-FAC is selected as an\n      optimization method.\n  \"\"\"\n  # Do not use CrossShardOptimizer with K-FAC. 
K-FAC now handles its own\n  # cross-replica syncronization automatically!\n\n  return classifier_mnist.make_train_op(\n      minibatch=minibatch,\n      batch_size=minibatch[0].get_shape().as_list()[0],\n      batch_loss=batch_loss,\n      layer_collection=layer_collection,\n      loss_fn=loss_fn,\n      prev_train_batch=None,\n      placement_strategy='replica_round_robin',\n      )\n\n\ndef mnist_input_fn(params):\n  dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)\n\n  # Shuffle before repeat is correct unless you want repeat cases in the\n  # same batch.\n  dataset = (\n      dataset.shuffle(num_examples).repeat().batch(\n          params['batch_size'],\n          drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE))\n  return dataset\n\n\ndef print_tensors(**tensors):\n  \"\"\"Host call function to print Tensors from the TPU during training.\"\"\"\n  print_op = tf.no_op()\n  for name in sorted(tensors):\n    with tf.control_dependencies([print_op]):\n      tensor = tensors[name]\n      if name in ['error', 'loss']:\n        tensor = tf.reduce_mean(tensor)\n      print_op = tf.Print(tensor, [tensor], message=name + '=')\n  with tf.control_dependencies([print_op]):\n    return tf.Print(0., [0.], message='------')\n\n\ndef _model_fn(features, labels, mode, params):\n  \"\"\"Estimator model_fn for an autoencoder with adaptive damping.\"\"\"\n  del params\n\n  training_model = classifier_mnist.Model()\n  layer_collection = kfac.LayerCollection()\n\n  def loss_fn(minibatch,\n              logits=None,\n              return_error=False):\n\n    features, labels = minibatch\n    if logits is None:\n      # Note we do not need to do anything like\n      # `with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):`\n      # here because Sonnet takes care of variable reuse for us as long as we\n      # call the same `training_model` module.  
Otherwise we would need to\n      # use variable reusing here.\n      logits = training_model(features)\n\n    return classifier_mnist.compute_loss(logits=logits,\n                                         labels=labels,\n                                         return_error=return_error)\n\n  logits = training_model(features)\n\n  pre_update_batch_loss, pre_update_batch_error = loss_fn(\n      (features, labels),\n      logits=logits,\n      return_error=True)\n\n  global_step = tf.train.get_or_create_global_step()\n\n  if mode == tf.estimator.ModeKeys.TRAIN:\n    layer_collection.register_softmax_cross_entropy_loss(logits,\n                                                         seed=FLAGS.seed + 1)\n    layer_collection.auto_register_layers()\n\n    train_op, kfac_optimizer = make_train_op(\n        (features, labels),\n        pre_update_batch_loss,\n        layer_collection,\n        loss_fn)\n\n    tensors_to_print = {\n        'learning_rate': tf.expand_dims(kfac_optimizer.learning_rate, 0),\n        'momentum': tf.expand_dims(kfac_optimizer.momentum, 0),\n        'damping': tf.expand_dims(kfac_optimizer.damping, 0),\n        'global_step': tf.expand_dims(global_step, 0),\n        'loss': tf.expand_dims(pre_update_batch_loss, 0),\n        'error': tf.expand_dims(pre_update_batch_error, 0),\n    }\n\n    if FLAGS.adapt_damping:\n      tensors_to_print['qmodel_change'] = tf.expand_dims(\n          kfac_optimizer.qmodel_change, 0)\n      tensors_to_print['rho'] = tf.expand_dims(kfac_optimizer.rho, 0)\n\n    return contrib_tpu.TPUEstimatorSpec(\n        mode=mode,\n        loss=pre_update_batch_loss,\n        train_op=train_op,\n        host_call=(print_tensors, tensors_to_print),\n        eval_metrics=None)\n\n  else:  # mode == tf.estimator.ModeKeys.{EVAL, PREDICT}:\n    return contrib_tpu.TPUEstimatorSpec(\n        mode=mode,\n        loss=pre_update_batch_loss,\n        eval_metrics=None)\n\n\ndef make_tpu_run_config(master, seed, model_dir, 
iterations_per_loop,\n                        save_checkpoints_steps):\n  return contrib_tpu.RunConfig(\n      master=master,\n      evaluation_master=master,\n      model_dir=model_dir,\n      save_checkpoints_steps=save_checkpoints_steps,\n      cluster=None,\n      tf_random_seed=seed,\n      tpu_config=contrib_tpu.TPUConfig(iterations_per_loop=iterations_per_loop))\n\n\ndef main(argv):\n  del argv  # Unused.\n\n  # If using update_damping_immediately resource variables must be enabled.\n  # (Although they probably will be by default on TPUs.)\n  if FLAGS.update_damping_immediately:\n    tf.enable_resource_variables()\n\n  tf.set_random_seed(FLAGS.seed)\n  # Invert using cholesky decomposition + triangular solve.  This is the only\n  # code path for matrix inversion supported on TPU right now.\n  kfac.utils.set_global_constants(posdef_inv_method='cholesky')\n  kfac.fisher_factors.set_global_constants(\n      eigenvalue_decomposition_threshold=10000)\n\n  if not FLAGS.use_sua_approx:\n    if FLAGS.use_custom_patches_op:\n      kfac.fisher_factors.set_global_constants(\n          use_patches_second_moment_op=True\n          )\n    else:\n      # Temporary measure to save memory with giant batches:\n      kfac.fisher_factors.set_global_constants(\n          sub_sample_inputs=True,\n          inputs_to_extract_patches_factor=0.1)\n\n  config = make_tpu_run_config(\n      FLAGS.master, FLAGS.seed, FLAGS.model_dir, FLAGS.iterations_per_loop,\n      FLAGS.save_checkpoints_steps)\n\n  estimator = contrib_tpu.TPUEstimator(\n      use_tpu=True,\n      model_fn=_model_fn,\n      config=config,\n      train_batch_size=FLAGS.batch_size,\n      eval_batch_size=1024)\n\n  estimator.train(\n      input_fn=mnist_input_fn,\n      max_steps=FLAGS.train_steps,\n      hooks=[])\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.app.run(main)\n"
  },
  {
    "path": "kfac/examples/convnet.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Train a ConvNet on MNIST using K-FAC.\n\nThis library demonstrates how to use K-FAC to train a 5-layer ConvNet on MNIST\nusing K-FAC.\n\nNote that this example is basically untuned and is not meant to work as an\nactual demonstration of the power of the method. It may not even converge. 
It\nmerely demonstrates how to set up K-FAC to run under the various standard\nmodes of operation in Tensorflow, like SyncReplicas, Estimator, etc.\n\nFor an example of the method tuned properly and working well, see for example\nthe autoencoder_mnist.py example, which replicates the exact experiment from\nthe original K-FAC paper.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport kfac\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.examples import mnist\n\n\n__all__ = [\n    \"conv_layer\",\n    \"fc_layer\",\n    \"max_pool_layer\",\n    \"build_model\",\n    \"minimize_loss_single_machine\",\n    \"distributed_grads_only_and_ops_chief_worker\",\n    \"distributed_grads_and_ops_dedicated_workers\",\n    \"train_mnist_single_machine\",\n    \"train_mnist_distributed_sync_replicas\",\n    \"train_mnist_multitower\"\n]\n\n\n# Inverse update ops will be run every _INVERT_EVERY iterations.\n_INVERT_EVERY = 10\n\n# Covariance matrices will be updated every _COV_UPDATE_EVERY iterations.\n_COV_UPDATE_EVERY = 1\n\n# Displays loss every _REPORT_EVERY iterations.\n_REPORT_EVERY = 10\n\n# Use manual registration\n_USE_MANUAL_REG = False\n\n\ndef fc_layer(layer_id, inputs, output_size):\n  \"\"\"Builds a fully connected layer.\n\n  Args:\n    layer_id: int. Integer ID for this layer's variables.\n    inputs: Tensor of shape [num_examples, input_size]. Each row corresponds\n      to a single example.\n    output_size: int. Number of output dimensions after fully connected layer.\n\n  Returns:\n    preactivations: Tensor of shape [num_examples, output_size]. Values of the\n      layer immediately before the activation function.\n    activations: Tensor of shape [num_examples, output_size]. 
Values of the\n      layer immediately after the activation function.\n    params: Tuple of (weights, bias), parameters for this layer.\n  \"\"\"\n  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.\n  layer = tf.layers.Dense(\n      output_size,\n      kernel_initializer=tf.random_normal_initializer(),\n      name=\"fc_%d\" % layer_id)\n  preactivations = layer(inputs)\n  activations = tf.nn.tanh(preactivations)\n\n  # layer.weights is a list. This converts it a (hashable) tuple.\n  return preactivations, activations, (layer.kernel, layer.bias)\n\n\ndef conv_layer(layer_id, inputs, kernel_size, out_channels):\n  \"\"\"Builds a convolutional layer with ReLU non-linearity.\n\n  Args:\n    layer_id: int. Integer ID for this layer's variables.\n    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row\n      corresponds to a single example.\n    kernel_size: int. Width and height of the convolution kernel. The kernel is\n      assumed to be square.\n    out_channels: int. Number of output features per pixel.\n\n  Returns:\n    preactivations: Tensor of shape [num_examples, width, height, out_channels].\n      Values of the layer immediately before the activation function.\n    activations: Tensor of shape [num_examples, width, height, out_channels].\n      Values of the layer immediately after the activation function.\n    params: Tuple of (kernel, bias), parameters for this layer.\n  \"\"\"\n  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.\n  layer = tf.layers.Conv2D(\n      out_channels,\n      kernel_size=[kernel_size, kernel_size],\n      kernel_initializer=tf.random_normal_initializer(stddev=0.01),\n      padding=\"SAME\",\n      name=\"conv_%d\" % layer_id)\n  preactivations = layer(inputs)\n  activations = tf.nn.relu(preactivations)\n\n  # layer.weights is a list. 
This converts it a (hashable) tuple.\n  return preactivations, activations, (layer.kernel, layer.bias)\n\n\ndef max_pool_layer(layer_id, inputs, kernel_size, stride):\n  \"\"\"Build a max-pooling layer.\n\n  Args:\n    layer_id: int. Integer ID for this layer's variables.\n    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row\n      corresponds to a single example.\n    kernel_size: int. Width and height to pool over per input channel. The\n      kernel is assumed to be square.\n    stride: int. Step size between pooling operations.\n\n  Returns:\n    Tensor of shape [num_examples, width/stride, height/stride, out_channels].\n    Result of applying max pooling to 'inputs'.\n  \"\"\"\n  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.\n  with tf.variable_scope(\"pool_%d\" % layer_id):\n    return tf.nn.max_pool(\n        inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1],\n        padding=\"SAME\",\n        name=\"pool\")\n\n\ndef build_model(examples,\n                labels,\n                num_labels,\n                layer_collection,\n                register_layers_manually=False):\n  \"\"\"Builds a ConvNet classification model.\n\n  Args:\n    examples: Tensor of shape [num_examples, num_features]. Represents inputs of\n      model.\n    labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted\n      by softmax for each example.\n    num_labels: int. Number of distinct values 'labels' can take on.\n    layer_collection: LayerCollection instance. Layers will be registered here.\n    register_layers_manually: bool. If True then register the layers with\n      layer_collection manually. (Default: False)\n\n  Returns:\n    loss: 0-D Tensor representing loss to be minimized.\n    accuracy: 0-D Tensor representing model's accuracy.\n  \"\"\"\n  # Build a ConvNet. 
For each layer with parameters, we'll keep track of the\n  # preactivations, activations, weights, and bias.\n  tf.logging.info(\"Building model.\")\n  pre0, act0, params0 = conv_layer(\n      layer_id=0, inputs=examples, kernel_size=5, out_channels=16)\n  act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)\n  pre2, act2, params2 = conv_layer(\n      layer_id=2, inputs=act1, kernel_size=5, out_channels=16)\n  act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)\n  flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])\n  logits, _, params4 = fc_layer(\n      layer_id=4, inputs=flat_act3, output_size=num_labels)\n  loss = tf.reduce_mean(\n      tf.nn.sparse_softmax_cross_entropy_with_logits(\n          labels=labels, logits=logits))\n  accuracy = tf.reduce_mean(\n      tf.cast(tf.equal(tf.cast(labels, dtype=tf.int32),\n                       tf.argmax(logits, axis=1, output_type=tf.int32)),\n              dtype=tf.float32))\n\n  with tf.device(\"/cpu:0\"):\n    tf.summary.scalar(\"loss\", loss)\n    tf.summary.scalar(\"accuracy\", accuracy)\n\n  layer_collection.register_softmax_cross_entropy_loss(\n      logits, name=\"logits\")\n\n  if register_layers_manually:\n    layer_collection.register_conv2d(params0, (1, 1, 1, 1), \"SAME\", examples,\n                                     pre0)\n    layer_collection.register_conv2d(params2, (1, 1, 1, 1), \"SAME\", act1,\n                                     pre2)\n    layer_collection.register_fully_connected(params4, flat_act3, logits)\n\n  return loss, accuracy\n\n\ndef minimize_loss_single_machine(loss,\n                                 accuracy,\n                                 layer_collection,\n                                 device=None,\n                                 session_config=None):\n  \"\"\"Minimize loss with K-FAC on a single machine.\n\n  Creates `PeriodicInvCovUpdateKfacOpt` which handles inverse and covariance\n  computation op placement and 
execution. A single Session is responsible for\n  running all of K-FAC's ops. The covariance and inverse update ops are placed\n  on `device`. All model variables are on CPU.\n\n  Args:\n    loss: 0-D Tensor. Loss to be minimized.\n    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.\n    layer_collection: LayerCollection instance describing model architecture.\n      Used by K-FAC to construct preconditioner.\n    device: string or None. The covariance and inverse update ops are run on\n      this device. If empty or None, the default device will be used.\n      (Default: None)\n    session_config: None or tf.ConfigProto. Configuration for tf.Session().\n\n  Returns:\n    final value for 'accuracy'.\n  \"\"\"\n  device_list = [] if not device else [device]\n\n  # Train with K-FAC.\n  g_step = tf.train.get_or_create_global_step()\n  optimizer = kfac.PeriodicInvCovUpdateKfacOpt(\n      invert_every=_INVERT_EVERY,\n      cov_update_every=_COV_UPDATE_EVERY,\n      learning_rate=0.0001,\n      cov_ema_decay=0.95,\n      damping=0.001,\n      layer_collection=layer_collection,\n      placement_strategy=\"round_robin\",\n      cov_devices=device_list,\n      inv_devices=device_list,\n      trans_devices=device_list,\n      momentum=0.9)\n\n  with tf.device(device):\n    train_op = optimizer.minimize(loss, global_step=g_step)\n\n  tf.logging.info(\"Starting training.\")\n  with tf.train.MonitoredTrainingSession(config=session_config) as sess:\n    while not sess.should_stop():\n      global_step_, loss_, accuracy_, _ = sess.run(\n          [g_step, loss, accuracy, train_op])\n\n      if global_step_ % _REPORT_EVERY == 0:\n        tf.logging.info(\"global_step: %d | loss: %f | accuracy: %s\",\n                        global_step_, loss_, accuracy_)\n\n  return accuracy_\n\n\ndef minimize_loss_single_machine_manual(loss,\n                                        accuracy,\n                                        layer_collection,\n                           
             device=None,\n                                        session_config=None):\n  \"\"\"Minimize loss with K-FAC on a single machine (Illustrative purpose only).\n\n  This function does inverse and covariance computation manually\n  for illustrative purpose. Check `minimize_loss_single_machine` for\n  automatic inverse and covariance op placement and execution.\n  A single Session is responsible for running all of K-FAC's ops. The covariance\n  and inverse update ops are placed on `device`. All model variables are on CPU.\n\n  Args:\n    loss: 0-D Tensor. Loss to be minimized.\n    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.\n    layer_collection: LayerCollection instance describing model architecture.\n      Used by K-FAC to construct preconditioner.\n    device: string or None. The covariance and inverse update ops are run on\n      this device. If empty or None, the default device will be used.\n      (Default: None)\n    session_config: None or tf.ConfigProto. 
Configuration for tf.Session().\n\n  Returns:\n    final value for 'accuracy'.\n  \"\"\"\n  device_list = [] if not device else [device]\n\n  # Train with K-FAC.\n  g_step = tf.train.get_or_create_global_step()\n  optimizer = kfac.KfacOptimizer(\n      learning_rate=0.0001,\n      cov_ema_decay=0.95,\n      damping=0.001,\n      layer_collection=layer_collection,\n      placement_strategy=\"round_robin\",\n      cov_devices=device_list,\n      inv_devices=device_list,\n      trans_devices=device_list,\n      momentum=0.9)\n  (cov_update_thunks,\n   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()\n\n  def make_update_op(update_thunks):\n    update_ops = [thunk() for thunk in update_thunks]\n    return tf.group(*update_ops)\n\n  cov_update_op = make_update_op(cov_update_thunks)\n  with tf.control_dependencies([cov_update_op]):\n    inverse_op = tf.cond(\n        tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),\n        lambda: make_update_op(inv_update_thunks), tf.no_op)\n    with tf.control_dependencies([inverse_op]):\n      with tf.device(device):\n        train_op = optimizer.minimize(loss, global_step=g_step)\n\n  tf.logging.info(\"Starting training.\")\n  with tf.train.MonitoredTrainingSession(config=session_config) as sess:\n    while not sess.should_stop():\n      global_step_, loss_, accuracy_, _ = sess.run(\n          [g_step, loss, accuracy, train_op])\n\n      if global_step_ % _REPORT_EVERY == 0:\n        tf.logging.info(\"global_step: %d | loss: %f | accuracy: %s\",\n                        global_step_, loss_, accuracy_)\n\n  return accuracy_\n\n\ndef _is_gradient_task(task_id, num_tasks):\n  \"\"\"Returns True if this task should update the weights.\"\"\"\n  if num_tasks < 3:\n    return True\n  return 0 <= task_id < 0.6 * num_tasks\n\n\ndef _is_cov_update_task(task_id, num_tasks):\n  \"\"\"Returns True if this task should update K-FAC's covariance matrices.\"\"\"\n  if num_tasks < 3:\n    return False\n  return 0.6 * num_tasks <= task_id < 
num_tasks - 1\n\n\ndef _is_inv_update_task(task_id, num_tasks):\n  \"\"\"Returns True if this task should update K-FAC's preconditioner.\"\"\"\n  if num_tasks < 3:\n    return False\n  return task_id == num_tasks - 1\n\n\ndef _num_gradient_tasks(num_tasks):\n  \"\"\"Number of tasks that will update weights.\"\"\"\n  if num_tasks < 3:\n    return num_tasks\n  return int(np.ceil(0.6 * num_tasks))\n\n\ndef _make_distributed_train_op(\n    task_id,\n    num_worker_tasks,\n    num_ps_tasks,\n    layer_collection\n):\n  \"\"\"Creates optimizer and distributed training op.\n\n  Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes\n  the train op.\n\n  Args:\n    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.\n    num_worker_tasks: int. Number of workers in this distributed training setup.\n    num_ps_tasks: int. Number of parameter servers holding variables. If 0,\n      parameter servers are not used.\n    layer_collection: LayerCollection instance describing model architecture.\n      Used by K-FAC to construct preconditioner.\n\n  Returns:\n    sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC\n      optimizer.\n    optimizer: Instance of `KfacOptimizer`.\n    global_step: `tensor`, Global step.\n  \"\"\"\n  tf.logging.info(\"Task id : %d\", task_id)\n  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):\n    global_step = tf.train.get_or_create_global_step()\n    optimizer = kfac.KfacOptimizer(\n        learning_rate=0.0001,\n        cov_ema_decay=0.95,\n        damping=0.001,\n        layer_collection=layer_collection,\n        momentum=0.9)\n    sync_optimizer = tf.train.SyncReplicasOptimizer(\n        opt=optimizer,\n        replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),\n        total_num_replicas=num_worker_tasks)\n    return sync_optimizer, optimizer, global_step\n\n\ndef distributed_grads_only_and_ops_chief_worker(\n    task_id, is_chief, num_worker_tasks, num_ps_tasks, 
master, checkpoint_dir,\n    loss, accuracy, layer_collection, invert_every=10):\n  \"\"\"Minimize loss with a synchronous implementation of K-FAC.\n\n  All workers perform gradient computation. Chief worker applies gradient after\n  averaging the gradients obtained from all the workers. All workers block\n  execution until the update is applied. Chief worker runs covariance and\n  inverse update ops. Covariance and inverse matrices are placed on parameter\n  servers in a round robin manner. For further details on synchronous\n  distributed optimization check `tf.train.SyncReplicasOptimizer`.\n\n  Args:\n    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.\n    is_chief: `boolean`, `True` if the worker is chief worker.\n    num_worker_tasks: int. Number of workers in this distributed training setup.\n    num_ps_tasks: int. Number of parameter servers holding variables. If 0,\n      parameter servers are not used.\n    master: string. IP and port of TensorFlow runtime process. Set to empty\n      string to run locally.\n    checkpoint_dir: string or None. Path to store checkpoints under.\n    loss: 0-D Tensor. Loss to be minimized.\n    accuracy: dict mapping strings to 0-D Tensors. 
Additional accuracy to\n      run with each step.\n    layer_collection: LayerCollection instance describing model architecture.\n      Used by K-FAC to construct preconditioner.\n    invert_every: `int`, Number of steps between update the inverse.\n\n  Returns:\n    final value for 'accuracy'.\n\n  Raises:\n    ValueError: if task_id >= num_worker_tasks.\n  \"\"\"\n\n  sync_optimizer, optimizer, global_step = _make_distributed_train_op(\n      task_id, num_worker_tasks, num_ps_tasks, layer_collection)\n  (cov_update_thunks,\n   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()\n\n  tf.logging.info(\"Starting training.\")\n  hooks = [sync_optimizer.make_session_run_hook(is_chief)]\n\n  def make_update_op(update_thunks):\n    update_ops = [thunk() for thunk in update_thunks]\n    return tf.group(*update_ops)\n\n  if is_chief:\n    cov_update_op = make_update_op(cov_update_thunks)\n    with tf.control_dependencies([cov_update_op]):\n      inverse_op = tf.cond(\n          tf.equal(tf.mod(global_step, invert_every), 0),\n          lambda: make_update_op(inv_update_thunks),\n          tf.no_op)\n      with tf.control_dependencies([inverse_op]):\n        train_op = sync_optimizer.minimize(loss, global_step=global_step)\n  else:\n    train_op = sync_optimizer.minimize(loss, global_step=global_step)\n\n  # Without setting allow_soft_placement=True there will be problems when\n  # the optimizer tries to place certain ops like \"mod\" on the GPU (which isn't\n  # supported).\n  config = tf.ConfigProto(allow_soft_placement=True)\n\n  with tf.train.MonitoredTrainingSession(\n      master=master,\n      is_chief=is_chief,\n      checkpoint_dir=checkpoint_dir,\n      hooks=hooks,\n      stop_grace_period_secs=0,\n      config=config) as sess:\n    while not sess.should_stop():\n      global_step_, loss_, accuracy_, _ = sess.run(\n          [global_step, loss, accuracy, train_op])\n      tf.logging.info(\"global_step: %d | loss: %f | accuracy: %s\", global_step_,\n  
                    loss_, accuracy_)\n  return accuracy_\n\n\ndef distributed_grads_and_ops_dedicated_workers(\n    task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,\n    loss, accuracy, layer_collection):\n  \"\"\"Minimize loss with a synchronous implementation of K-FAC.\n\n  Different workers are responsible for different parts of K-FAC's Ops. The\n  first 60% of tasks compute gradients; the next 20% accumulate covariance\n  statistics; the last 20% invert the matrices used to precondition gradients.\n  The chief worker computes and applies the update.\n\n  Args:\n    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.\n    is_chief: `boolean`, `True` if the worker is chief worker.\n    num_worker_tasks: int. Number of workers in this distributed training setup.\n    num_ps_tasks: int. Number of parameter servers holding variables. If 0,\n      parameter servers are not used.\n    master: string. IP and port of TensorFlow runtime process. Set to empty\n      string to run locally.\n    checkpoint_dir: string or None. Path to store checkpoints under.\n    loss: 0-D Tensor. Loss to be minimized.\n    accuracy: dict mapping strings to 0-D Tensors. 
Additional accuracy to\n      run with each step.\n    layer_collection: LayerCollection instance describing model architecture.\n      Used by K-FAC to construct preconditioner.\n\n  Returns:\n    final value for 'accuracy'.\n\n  Raises:\n    ValueError: if task_id >= num_worker_tasks.\n  \"\"\"\n  sync_optimizer, optimizer, global_step = _make_distributed_train_op(\n      task_id, num_worker_tasks, num_ps_tasks, layer_collection)\n  _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()\n  train_op = sync_optimizer.minimize(loss, global_step=global_step)\n  inv_update_queue = kfac.op_queue.OpQueue(inv_update_ops)\n\n  tf.logging.info(\"Starting training.\")\n  is_chief = (task_id == 0)\n  hooks = [sync_optimizer.make_session_run_hook(is_chief)]\n\n  # Without setting allow_soft_placement=True there will be problems when\n  # the optimizer tries to place certain ops like \"mod\" on the GPU (which isn't\n  # supported).\n  config = tf.ConfigProto(allow_soft_placement=True)\n\n  with tf.train.MonitoredTrainingSession(\n      master=master,\n      is_chief=is_chief,\n      checkpoint_dir=checkpoint_dir,\n      hooks=hooks,\n      stop_grace_period_secs=0,\n      config=config) as sess:\n    while not sess.should_stop():\n      # Choose which op this task is responsible for running.\n      if _is_gradient_task(task_id, num_worker_tasks):\n        learning_op = train_op\n      elif _is_cov_update_task(task_id, num_worker_tasks):\n        learning_op = cov_update_op\n      elif _is_inv_update_task(task_id, num_worker_tasks):\n        learning_op = inv_update_queue.next_op(sess)\n      else:\n        raise ValueError(\"Which op should task %d do?\" % task_id)\n\n      global_step_, loss_, accuracy_, _ = sess.run(\n          [global_step, loss, accuracy, learning_op])\n      tf.logging.info(\"global_step: %d | loss: %f | accuracy: %s\", global_step_,\n                      loss_, accuracy_)\n\n  return accuracy_\n\n\ndef 
train_mnist_single_machine(num_epochs,\n                               use_fake_data=False,\n                               device=None,\n                               manual_op_exec=False):\n  \"\"\"Train a ConvNet on MNIST.\n\n  Args:\n    num_epochs: int. Number of passes to make over the training set.\n    use_fake_data: bool. If True, generate a synthetic dataset.\n    device: string or None. The covariance and inverse update ops are run on\n      this device. If empty or None, the default device will be used.\n      (Default: None)\n    manual_op_exec: bool, If `True` then `minimize_loss_single_machine_manual`\n      is called for training which handles inverse and covariance computation.\n      This is shown only for illustrative purpose. Otherwise\n      `minimize_loss_single_machine` is called which relies on\n      `PeriodicInvCovUpdateKfacOpt` for op placement and execution.\n\n  Returns:\n    accuracy of model on the final minibatch of training data.\n  \"\"\"\n  # Load a dataset.\n  tf.logging.info(\"Loading MNIST into memory.\")\n  (examples, labels) = mnist.load_mnist_as_iterator(num_epochs,\n                                                    128,\n                                                    use_fake_data=use_fake_data,\n                                                    flatten_images=False)\n\n  # Build a ConvNet.\n  layer_collection = kfac.LayerCollection()\n  loss, accuracy = build_model(\n      examples, labels, num_labels=10, layer_collection=layer_collection,\n      register_layers_manually=_USE_MANUAL_REG)\n  if not _USE_MANUAL_REG:\n    layer_collection.auto_register_layers()\n\n  # Without setting allow_soft_placement=True there will be problems when\n  # the optimizer tries to place certain ops like \"mod\" on the GPU (which isn't\n  # supported).\n  config = tf.ConfigProto(allow_soft_placement=True)\n\n  # Fit model.\n  if manual_op_exec:\n    return minimize_loss_single_machine_manual(\n        loss, accuracy, layer_collection, 
device=device, session_config=config)\n  else:\n    return minimize_loss_single_machine(\n        loss, accuracy, layer_collection, device=device, session_config=config)\n\n\ndef train_mnist_multitower(num_epochs, num_towers,\n                           devices, use_fake_data=False, session_config=None):\n  \"\"\"Train a ConvNet on MNIST.\n\n  Training data is split equally among the towers. Each tower computes loss on\n  its own batch of data and the loss is aggregated on the CPU. The model\n  variables are placed on first tower. The covariance and inverse update ops\n  and variables are placed on specified devices in a round robin manner.\n\n  Args:\n    num_epochs: int. Number of passes to make over the training set.\n    num_towers: int. Number of towers.\n    devices: list of strings. List of devices to place the towers.\n    use_fake_data: bool. If True, generate a synthetic dataset.\n    session_config: None or tf.ConfigProto. Configuration for tf.Session().\n\n  Returns:\n    accuracy of model on the final minibatch of training data.\n  \"\"\"\n  num_towers = 1 if not devices else len(devices)\n  # Load a dataset.\n  tf.logging.info(\"Loading MNIST into memory.\")\n  tower_batch_size = 128\n  batch_size = tower_batch_size * num_towers\n  tf.logging.info(\n      (\"Loading MNIST into memory. Using batch_size = %d = %d towers * %d \"\n       \"tower batch size.\") % (batch_size, num_towers, tower_batch_size))\n  (examples, labels) = mnist.load_mnist_as_iterator(num_epochs,\n                                                    batch_size,\n                                                    use_fake_data=use_fake_data,\n                                                    flatten_images=False)\n\n  # Split minibatch across towers.\n  examples = tf.split(examples, num_towers)\n  labels = tf.split(labels, num_towers)\n\n  # Build an MLP. 
Each tower's layers will be added to the LayerCollection.\n  layer_collection = kfac.LayerCollection()\n  tower_results = []\n  for tower_id in range(num_towers):\n    with tf.device(devices[tower_id]):\n      with tf.name_scope(\"tower%d\" % tower_id):\n        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):\n          tf.logging.info(\"Building tower %d.\" % tower_id)\n          tower_results.append(\n              build_model(\n                  examples[tower_id],\n                  labels[tower_id],\n                  10,\n                  layer_collection,\n                  register_layers_manually=_USE_MANUAL_REG))\n  losses, accuracies = zip(*tower_results)\n  # When using multiple towers we only want to perform automatic\n  # registration once, after the final tower is made\n  if not _USE_MANUAL_REG:\n    layer_collection.auto_register_layers()\n\n  # Average across towers.\n  loss = tf.reduce_mean(losses)\n  accuracy = tf.reduce_mean(accuracies)\n\n  # Fit model.\n  g_step = tf.train.get_or_create_global_step()\n  optimizer = kfac.PeriodicInvCovUpdateKfacOpt(\n      invert_every=_INVERT_EVERY,\n      cov_update_every=_COV_UPDATE_EVERY,\n      learning_rate=0.0001,\n      cov_ema_decay=0.95,\n      damping=0.001,\n      layer_collection=layer_collection,\n      placement_strategy=\"round_robin\",\n      cov_devices=devices,\n      inv_devices=devices,\n      trans_devices=devices,\n      momentum=0.9)\n\n  with tf.device(devices[0]):\n    train_op = optimizer.minimize(loss, global_step=g_step)\n\n  # Without setting allow_soft_placement=True there will be problems when\n  # the optimizer tries to place certain ops like \"mod\" on the GPU (which isn't\n  # supported).\n  if not session_config:\n    session_config = tf.ConfigProto(allow_soft_placement=True)\n\n  tf.logging.info(\"Starting training.\")\n  with tf.train.MonitoredTrainingSession(config=session_config) as sess:\n    while not sess.should_stop():\n      global_step_, 
loss_, accuracy_, _ = sess.run(\n          [g_step, loss, accuracy, train_op])\n\n      if global_step_ % _REPORT_EVERY == 0:\n        tf.logging.info(\"global_step: %d | loss: %f | accuracy: %s\",\n                        global_step_, loss_, accuracy_)\n\n\ndef train_mnist_distributed_sync_replicas(task_id,\n                                          is_chief,\n                                          num_worker_tasks,\n                                          num_ps_tasks,\n                                          master,\n                                          num_epochs,\n                                          op_strategy,\n                                          use_fake_data=False):\n  \"\"\"Train a ConvNet on MNIST using Sync replicas optimizer.\n\n  Args:\n    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.\n    is_chief: `boolean`, `True` if the worker is chief worker.\n    num_worker_tasks: int. Number of workers in this distributed training setup.\n    num_ps_tasks: int. Number of parameter servers holding variables.\n    master: string. IP and port of TensorFlow runtime process.\n    num_epochs: int. Number of passes to make over the training set.\n    op_strategy: `string`, Strategy to run the covariance and inverse\n      ops. If op_strategy == `chief_worker` then covariance and inverse\n      update ops are run on chief worker otherwise they are run on dedicated\n      workers.\n\n    use_fake_data: bool. 
If True, generate a synthetic dataset.\n\n  Returns:\n    accuracy of model on the final minibatch of training data.\n\n  Raises:\n    ValueError: If `op_strategy` not in [\"chief_worker\", \"dedicated_workers\"].\n  \"\"\"\n  # Load a dataset.\n  tf.logging.info(\"Loading MNIST into memory.\")\n  (examples, labels) = mnist.load_mnist_as_iterator(num_epochs,\n                                                    128,\n                                                    use_fake_data=use_fake_data,\n                                                    flatten_images=False)\n\n  # Build a ConvNet.\n  layer_collection = kfac.LayerCollection()\n  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):\n    loss, accuracy = build_model(\n        examples, labels, num_labels=10, layer_collection=layer_collection,\n        register_layers_manually=_USE_MANUAL_REG)\n  if not _USE_MANUAL_REG:\n    layer_collection.auto_register_layers()\n\n  # Fit model.\n  checkpoint_dir = None\n  if op_strategy == \"chief_worker\":\n    return distributed_grads_only_and_ops_chief_worker(\n        task_id, is_chief, num_worker_tasks, num_ps_tasks, master,\n        checkpoint_dir, loss, accuracy, layer_collection)\n  elif op_strategy == \"dedicated_workers\":\n    return distributed_grads_and_ops_dedicated_workers(\n        task_id, is_chief, num_worker_tasks, num_ps_tasks, master,\n        checkpoint_dir, loss, accuracy, layer_collection)\n  else:\n    raise ValueError(\"Only supported op strategies are : {}, {}\".format(\n        \"chief_worker\", \"dedicated_workers\"))\n\n\ndef train_mnist_estimator(num_epochs, use_fake_data=False):\n  \"\"\"Train a ConvNet on MNIST using tf.estimator.\n\n  Args:\n    num_epochs: int. Number of passes to make over the training set.\n    use_fake_data: bool. 
If True, generate a synthetic dataset.\n\n  Returns:\n    accuracy of model on the final minibatch of training data.\n  \"\"\"\n\n  # Load a dataset.\n  def input_fn():\n    tf.logging.info(\"Loading MNIST into memory.\")\n    return mnist.load_mnist_as_iterator(num_epochs=num_epochs,\n                                        batch_size=64,\n                                        flatten_images=False,\n                                        use_fake_data=use_fake_data)\n\n  def model_fn(features, labels, mode, params):\n    \"\"\"Model function for MLP trained with K-FAC.\n\n    Args:\n      features: Tensor of shape [batch_size, input_size]. Input features.\n      labels: Tensor of shape [batch_size]. Target labels for training.\n      mode: tf.estimator.ModeKey. Must be TRAIN.\n      params: ignored.\n\n    Returns:\n      EstimatorSpec for training.\n\n    Raises:\n      ValueError: If 'mode' is anything other than TRAIN.\n    \"\"\"\n    del params\n\n    if mode != tf.estimator.ModeKeys.TRAIN:\n      raise ValueError(\"Only training is supported with this API.\")\n\n    # Build a ConvNet.\n    layer_collection = kfac.LayerCollection()\n    loss, accuracy = build_model(\n        features, labels, num_labels=10, layer_collection=layer_collection,\n        register_layers_manually=_USE_MANUAL_REG)\n    if not _USE_MANUAL_REG:\n      layer_collection.auto_register_layers()\n\n    # Train with K-FAC.\n    global_step = tf.train.get_or_create_global_step()\n    optimizer = kfac.KfacOptimizer(\n        learning_rate=tf.train.exponential_decay(\n            0.00002, global_step, 10000, 0.5, staircase=True),\n        cov_ema_decay=0.95,\n        damping=0.001,\n        layer_collection=layer_collection,\n        momentum=0.9)\n\n    (cov_update_thunks,\n     inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()\n\n    def make_update_op(update_thunks):\n      update_ops = [thunk() for thunk in update_thunks]\n      return tf.group(*update_ops)\n\n    def 
make_batch_executed_op(update_thunks, batch_size=1):\n      return tf.group(*kfac.utils.batch_execute(\n          global_step, update_thunks, batch_size=batch_size))\n\n    # Run cov_update_op every step. Run 1 inv_update_ops per step.\n    cov_update_op = make_update_op(cov_update_thunks)\n    with tf.control_dependencies([cov_update_op]):\n      # But make sure to execute all the inverse ops on the first step\n      inverse_op = tf.cond(tf.equal(global_step, 0),\n                           lambda: make_update_op(inv_update_thunks),\n                           lambda: make_batch_executed_op(inv_update_thunks))\n      with tf.control_dependencies([inverse_op]):\n        train_op = optimizer.minimize(loss, global_step=global_step)\n\n    # Print metrics every 5 sec.\n    hooks = [\n        tf.train.LoggingTensorHook(\n            {\n                \"loss\": loss,\n                \"accuracy\": accuracy\n            }, every_n_secs=5),\n    ]\n    return tf.estimator.EstimatorSpec(\n        mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)\n\n  run_config = tf.estimator.RunConfig(\n      model_dir=\"/tmp/mnist\", save_checkpoints_steps=1, keep_checkpoint_max=100)\n\n  # Train until input_fn() is empty with Estimator. This is a prerequisite for\n  # TPU compatibility.\n  estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)\n  estimator.train(input_fn=input_fn)\n"
  },
  {
    "path": "kfac/examples/keras/KFAC_vs_Adam_Experiment.md",
    "content": "# KFAC vs Adam Experiment\n\n## Set Up\n\nWe compare KFAC and Adam on a RESNET-20 on the CIFAR10 dataset. We split CIFAR10\ninto training (40k), validation (10k), and test (10k) sets. We ran a random\nhyperparameter search where the best hyperparameters were chosen by the run that\nreaches 89% validation accuracy first in terms of number of steps. We decay both\nlearning rate and damping/epsilon exponentially. The final learning rate is\nfixed at 1e-4, final damping (KFAC) at 1e-6, and final epsilon (Adam) at 1e-8.\nBelow are the ranges of the tuned hyperparameters. The random search samples all\nthe hyperparameters from a log uniform scale:\n\n| Hyperparameter            | Min  | Max   |\n|---------------------------|------|-------|\n| Init Learning Rate        | 1e-2 | 10.0  |\n| Init Damping (KFAC)       | 1e-2 | 100.0 |\n| Init Epsilon (Adam)       | 1e-4 | 1.0   |\n| 1 - Learning Rate Decay   | 1e-4 | 0.1   |\n| 1 - Damping/Epsilon Decay | 1e-4 | 0.1   |\n| 1 - Momentum              | 1e-2 | 0.3   |\n\nThe initial tuning run was with seed 20190524 with the GPU training script on an\nNVIDIA Tesla P100. 
Then, after choosing the best hyperparameters, we ran those\nhyperparameters with the following 10 random seeds: 351515, 382980, 934126,\n891369, 64379, 402680, 672242, 421590, 498163, 448799.\n\n## Results\n\nThe chosen hyperparameters were the following (to 6 decimal places):\n\n| Hyperparameter            | KFAC     | Adam     |\n|---------------------------|----------|----------|\n| Init Learning Rate        | 0.227214 | 2.242663 |\n| Init Damping (KFAC)       | 0.288721 |          |\n| Init Epsilon (Adam)       |          | 0.183230 |\n| 1 - Learning Rate Decay   | 0.001090 | 0.000610 |\n| 1 - Damping/Epsilon Decay | 0.000287 | 0.000213 |\n| 1 - Momentum              | 0.018580 | 0.029656 |\n\n## Training Curves\n\nBelow are the loss and accuracy training curves with the training and test sets.\nThe line represents the mean of the 10 seed runs and the coloured region\nrepresents the bootstrapped standard deviation. KFAC reaches 89% validation\naccuracy at step 4640 and Adam at step 6560 (measurements were taken every 40\nsteps).\n\n![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_v_adam_loss_curve.png?raw=true)\n\n![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_v_adam_accuracy_curve.png?raw=true)\n\nAmong the other runs, KFAC decreases training loss quicker than Adam early in\ntraining, then shows similar performance later in training.\n\n## Hyperparameter Analysis\n\nWe offer some analysis of the learning rate and damping for KFAC to aid in\nchoosing appropriate values for these hyperparameters. Plots with the rest of\nthe hyperparameters for both KFAC and Adam are in the plots folder.\n\n![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_lr_v_damping.png?raw=true)\n\nIn general, a higher learning rate requires a higher damping. 
A large learning\nrate with low damping leads to divergence, whereas a low learning rate with high\ndamping leads to SGD-like behaviour, which is suboptimal. The plot above shows\nlittle correlation due to the decay schedules playing a large role, which is\nshown below:\n\n![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_damping_v_dampingdecay.png?raw=true)\n\nA fast damping decay allows for faster training, but can easily lead to\ndivergence. The best runs are often close to diverging.\n\n![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_lr_v_lrdecay.png?raw=true)\n\nAs expected, a high learning rate with a low decay can lead to divergence.\n\n![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_lrdecay_v_dampingdecay.png?raw=true)\n\nJust like with the learning rate and damping, the learning rate decay should\nbe proportional to the damping decay to prevent divergence while training quickly.\n"
  },
  {
    "path": "kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"_DDaAex5Q7u-\"\n      },\n      \"source\": [\n        \"##### Copyright 2019 The TensorFlow Authors.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"cellView\": \"both\",\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"W1dWWdNHQ9L0\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@title Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n\",\n        \"# you may not use this file except in compliance with the License.\\n\",\n        \"# You may obtain a copy of the License at\\n\",\n        \"#\\n\",\n        \"# https://www.apache.org/licenses/LICENSE-2.0\\n\",\n        \"#\\n\",\n        \"# Unless required by applicable law or agreed to in writing, software\\n\",\n        \"# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n\",\n        \"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n\",\n        \"# See the License for the specific language governing permissions and\\n\",\n        \"# limitations under the License.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"_C170SDp6jBt\"\n      },\n      \"source\": [\n        \"# KFAC vs Adam on CIFAR10 on a GPU\\n\",\n        \"\\n\",\n        \"This notebook contains the code used to run the experiment comparing KFAC and Adam on CIFAR 10 with a Resnet-20. This was run on a NVIDIA Tesla P100 for the experiment. 
It can be run on a public GPU colab instance.\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tensorflow/kfac/blob/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10.ipynb)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"rw0qz2RWkLeJ\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"!pip install kfac\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"LfGyhnaOsgYu\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import tensorflow as tf\\n\",\n        \"import tensorflow_datasets as tfds\\n\",\n        \"import math\\n\",\n        \"import kfac\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"DYWIY0C380ye\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"TRAINING_SIZE = 40000\\n\",\n        \"VALIDATION_SIZE = 10000\\n\",\n        \"TEST_SIZE = 10000\\n\",\n        \"SEED = 20190524\\n\",\n        \"\\n\",\n        \"num_training_steps = 7500\\n\",\n        \"batch_size = 1000\\n\",\n        \"layers = tf.keras.layers\\n\",\n        \"\\n\",\n        \"# We take the ceiling because we do not drop the remainder of the batch\\n\",\n        \"compute_steps_per_epoch = lambda x: int(math.ceil(1. 
* x / batch_size))\\n\",\n        \"steps_per_epoch = compute_steps_per_epoch(TRAINING_SIZE)\\n\",\n        \"val_steps = compute_steps_per_epoch(VALIDATION_SIZE)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"GfeTgsbh5G4g\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"optimizer_name = 'kfac'  # 'kfac' or 'adam'\\n\",\n        \"\\n\",\n        \"# Best Hyperparameters from the Random Search\\n\",\n        \"if optimizer_name == 'kfac':\\n\",\n        \"  init_learning_rate = 0.22721400059936694\\n\",\n        \"  final_learning_rate = 1e-04\\n\",\n        \"  init_damping = 0.28872127217018184\\n\",\n        \"  final_damping = 1e-6\\n\",\n        \"  momentum = 1 - 0.018580394981260295\\n\",\n        \"  lr_decay_rate = 1 - 0.001090107322908028\\n\",\n        \"  damping_decay_rate = 1 - 0.0002870880729016523\\n\",\n        \"elif optimizer_name == 'adam':\\n\",\n        \"  init_learning_rate = 2.24266320779\\n\",\n        \"  final_learning_rate = 1e-4\\n\",\n        \"  init_epsilon = 0.183230038808\\n\",\n        \"  final_epsilon = 1e-8\\n\",\n        \"  momentum = 1 - 0.0296561513388\\n\",\n        \"  lr_decay_rate = 1 - 0.000610416031571\\n\",\n        \"  epsilon_decay_rate = 1 - 0.000212682338199\\n\",\n        \"else:\\n\",\n        \"  raise ValueError('Ensure optimizer_name is kfac or adam')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"v3vSki-usp9k\"\n      },\n      \"source\": [\n        \"## Input Pipeline\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"D2U3i5kgssy_\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def 
_parse_fn(x):\\n\",\n        \"  image, label = x['image'], x['label']\\n\",\n        \"  image = tf.cast(image, tf.float32)\\n\",\n        \"  label = tf.cast(label, tf.int32)\\n\",\n        \"  image = image / 127.5 - 1\\n\",\n        \"  return image, label\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def _augment_image(image, crop_amount, seed=None):\\n\",\n        \"  # Random Brightness, Contrast, Jpeg Quality, Hue, and Saturation did not\\n\",\n        \"  # seem to work well as augmentations for our training specifications\\n\",\n        \"  input_shape = image.shape.as_list()\\n\",\n        \"  cropped_size = [input_shape[0] - crop_amount,\\n\",\n        \"                  input_shape[1] - crop_amount,\\n\",\n        \"                  input_shape[2]]\\n\",\n        \"  flipped = tf.image.random_flip_left_right(image, seed)\\n\",\n        \"  cropped = tf.image.random_crop(flipped, cropped_size, seed)\\n\",\n        \"  return tf.image.pad_to_bounding_box(image=cropped,\\n\",\n        \"                                      offset_height=crop_amount // 2,\\n\",\n        \"                                      offset_width=crop_amount // 2,\\n\",\n        \"                                      target_height=input_shape[0],\\n\",\n        \"                                      target_width=input_shape[1])\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def _get_raw_data():\\n\",\n        \"  # We split the training data into training and validation ourselves for\\n\",\n        \"  # hyperparameter tuning.\\n\",\n        \"  training_pct = int(100.0 * TRAINING_SIZE / (TRAINING_SIZE + VALIDATION_SIZE))\\n\",\n        \"  train_split = tfds.Split.TRAIN.subsplit(tfds.percent[:training_pct])\\n\",\n        \"  validation_split = tfds.Split.TRAIN.subsplit(tfds.percent[training_pct:])\\n\",\n        \"\\n\",\n        \"  train_data, info = tfds.load('cifar10:3.*.*', with_info=True, split=train_split)\\n\",\n        \"  val_data = 
tfds.load('cifar10:3.*.*', split=validation_split)\\n\",\n        \"  test_data = tfds.load('cifar10:3.*.*', split='test')\\n\",\n        \"\\n\",\n        \"  input_shape = info.features['image'].shape\\n\",\n        \"  num_classes = info.features['label'].num_classes\\n\",\n        \"  info = {'input_shape': input_shape, 'num_classes': num_classes}\\n\",\n        \"  return info, train_data, val_data, test_data\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def get_input_pipeline(batch_size=None,\\n\",\n        \"                       use_augmentation=True,\\n\",\n        \"                       seed=None,\\n\",\n        \"                       crop_amount=6,\\n\",\n        \"                       drop_remainder=False,\\n\",\n        \"                       repeat_validation=True):\\n\",\n        \"  \\\"\\\"\\\"Creates CIFAR10 Data Pipeline.\\n\",\n        \"\\n\",\n        \"  Args:\\n\",\n        \"    batch_size (int): Batch size used for training.\\n\",\n        \"    use_augmentation (bool): If true, applies random horizontal flips and crops\\n\",\n        \"      then pads to images.\\n\",\n        \"    seed (int): Random seed used for augmentation operations.\\n\",\n        \"    crop_amount (int): Number of pixels to crop from the height and width of the\\n\",\n        \"      image. So, the cropped image will be [height - crop_amount, width -\\n\",\n        \"      crop_amount, channels] before it is padded to restore its original size.\\n\",\n        \"    drop_remainder (bool): Whether to drop the remainder of the batch. Needs to\\n\",\n        \"      be true to work on TPUs.\\n\",\n        \"    repeat_validation (bool): Whether to repeat the validation set. 
Test set is\\n\",\n        \"      never repeated.\\n\",\n        \"\\n\",\n        \"  Returns:\\n\",\n        \"    A tuple with an info dict (with input_shape (tuple) and number of classes\\n\",\n        \"    (int)) and data dict (train_data (tf.DatasetAdapter), validation_data,\\n\",\n        \"    (tf.DatasetAdapter) and test_data (tf.DatasetAdapter))\\n\",\n        \"  \\\"\\\"\\\"\\n\",\n        \"  info, train_data, val_data, test_data = _get_raw_data()\\n\",\n        \"\\n\",\n        \"  if not batch_size:\\n\",\n        \"    batch_size = max(TRAINING_SIZE, VALIDATION_SIZE, TEST_SIZE)\\n\",\n        \"\\n\",\n        \"  train_data = train_data.map(_parse_fn).shuffle(8192, seed=seed).repeat()\\n\",\n        \"  if use_augmentation:\\n\",\n        \"    train_data = train_data.map(\\n\",\n        \"        lambda x, y: (_augment_image(x, crop_amount, seed), y))\\n\",\n        \"  train_data = train_data.batch(\\n\",\n        \"      min(batch_size, TRAINING_SIZE), drop_remainder=drop_remainder)\\n\",\n        \"  train_data = train_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\\n\",\n        \"\\n\",\n        \"  val_data = val_data.map(_parse_fn)\\n\",\n        \"  if repeat_validation:\\n\",\n        \"    val_data = val_data.repeat()\\n\",\n        \"  val_data = val_data.batch(\\n\",\n        \"      min(batch_size, VALIDATION_SIZE), drop_remainder=drop_remainder)\\n\",\n        \"  val_data = val_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\\n\",\n        \"\\n\",\n        \"  # Don't repeat test data because it is only used once to evaluate at the end.\\n\",\n        \"  test_data = test_data.map(_parse_fn)\\n\",\n        \"  if repeat_validation:\\n\",\n        \"    test_data = test_data.repeat()\\n\",\n        \"  test_data = test_data.batch(\\n\",\n        \"      min(batch_size, TEST_SIZE), drop_remainder=drop_remainder)\\n\",\n        \"  test_data = test_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\\n\",\n  
      \"\\n\",\n        \"  data = {'train': train_data, 'validation': val_data, 'test': test_data}\\n\",\n        \"  return data, info\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"SLvlpsups2aR\"\n      },\n      \"source\": [\n        \"## Model - Resnet V2\\n\",\n        \"\\n\",\n        \"Based on https://keras.io/examples/cifar10_resnet/. The only difference is that tf.keras layer implementations are used.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"Cch3Ld5Ds4i2\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def resnet_layer(inputs,\\n\",\n        \"                 num_filters=16,\\n\",\n        \"                 kernel_size=3,\\n\",\n        \"                 strides=1,\\n\",\n        \"                 activation='relu',\\n\",\n        \"                 batch_normalization=True,\\n\",\n        \"                 conv_first=True):\\n\",\n        \"  \\\"\\\"\\\"2D Convolution-Batch Normalization-Activation stack builder.\\n\",\n        \"\\n\",\n        \"  Based on https://keras.io/examples/cifar10_resnet/.\\n\",\n        \"\\n\",\n        \"  Args:\\n\",\n        \"    inputs (tensor): input tensor from input image or previous layer\\n\",\n        \"    num_filters (int): Conv2D number of filters\\n\",\n        \"    kernel_size (int): Conv2D square kernel dimensions\\n\",\n        \"    strides (int): Conv2D square stride dimensions\\n\",\n        \"    activation (string): activation name\\n\",\n        \"    batch_normalization (bool): whether to include batch normalization\\n\",\n        \"    conv_first (bool): conv-bn-activation (True) or bn-activation-conv (False)\\n\",\n        \"\\n\",\n        \"  Returns:\\n\",\n        \"    x (tensor): tensor as input to the next layer\\n\",\n        
\"  \\\"\\\"\\\"\\n\",\n        \"  conv = layers.Conv2D(num_filters,\\n\",\n        \"                       kernel_size=kernel_size,\\n\",\n        \"                       strides=strides,\\n\",\n        \"                       padding='same',\\n\",\n        \"                       kernel_initializer='he_normal',\\n\",\n        \"                       kernel_regularizer=tf.keras.regularizers.l2(1e-4))\\n\",\n        \"\\n\",\n        \"  x = inputs\\n\",\n        \"  if conv_first:\\n\",\n        \"    x = conv(x)\\n\",\n        \"    if batch_normalization:\\n\",\n        \"      x = layers.BatchNormalization()(x)\\n\",\n        \"    if activation is not None:\\n\",\n        \"      x = layers.Activation(activation)(x)\\n\",\n        \"  else:\\n\",\n        \"    if batch_normalization:\\n\",\n        \"      x = layers.BatchNormalization()(x)\\n\",\n        \"    if activation is not None:\\n\",\n        \"      x = layers.Activation(activation)(x)\\n\",\n        \"    x = conv(x)\\n\",\n        \"  return x\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def resnet_v2(input_shape, depth, num_classes=10):\\n\",\n        \"  \\\"\\\"\\\"ResNet Version 2 Model builder [b].\\n\",\n        \"\\n\",\n        \"    Based on https://keras.io/examples/cifar10_resnet/.\\n\",\n        \"\\n\",\n        \"    Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D or also known as\\n\",\n        \"    bottleneck layer\\n\",\n        \"    First shortcut connection per layer is 1 x 1 Conv2D.\\n\",\n        \"    Second and onwards shortcut connection is identity.\\n\",\n        \"    At the beginning of each stage, the feature map size is halved (downsampled)\\n\",\n        \"    by a convolutional layer with strides=2, while the number of filter maps is\\n\",\n        \"    doubled. 
Within each stage, the layers have the same number filters and the\\n\",\n        \"    same filter map sizes.\\n\",\n        \"    Features maps sizes:\\n\",\n        \"    conv1  : 32x32,  16\\n\",\n        \"    stage 0: 32x32,  64\\n\",\n        \"    stage 1: 16x16, 128\\n\",\n        \"    stage 2:  8x8,  256\\n\",\n        \"\\n\",\n        \"    Args:\\n\",\n        \"      input_shape (tuple/list): shape of input image tensor\\n\",\n        \"      depth (int): number of core convolutional layers\\n\",\n        \"      num_classes (int): number of classes (CIFAR10 has 10)\\n\",\n        \"\\n\",\n        \"    Returns:\\n\",\n        \"      model (Model): Keras model instance\\n\",\n        \"    \\\"\\\"\\\"\\n\",\n        \"  if (depth - 2) % 9 != 0:\\n\",\n        \"    raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])')\\n\",\n        \"  # Start model definition.\\n\",\n        \"  num_filters_in = 16\\n\",\n        \"  num_res_blocks = int((depth - 2) / 9)\\n\",\n        \"\\n\",\n        \"  inputs = tf.keras.Input(shape=input_shape)\\n\",\n        \"  # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths\\n\",\n        \"  x = resnet_layer(inputs=inputs, num_filters=num_filters_in, conv_first=True)\\n\",\n        \"\\n\",\n        \"  # Instantiate the stack of residual units\\n\",\n        \"  for stage in range(3):\\n\",\n        \"    for res_block in range(num_res_blocks):\\n\",\n        \"      activation = 'relu'\\n\",\n        \"      batch_normalization = True\\n\",\n        \"      strides = 1\\n\",\n        \"      if stage == 0:\\n\",\n        \"        num_filters_out = num_filters_in * 4\\n\",\n        \"        if res_block == 0:  # first layer and first stage\\n\",\n        \"          activation = None\\n\",\n        \"          batch_normalization = False\\n\",\n        \"      else:\\n\",\n        \"        num_filters_out = num_filters_in * 2\\n\",\n        \"        if res_block == 0:  # first 
layer but not first stage\\n\",\n        \"          strides = 2  # downsample\\n\",\n        \"\\n\",\n        \"      # bottleneck residual unit\\n\",\n        \"      y = resnet_layer(inputs=x,\\n\",\n        \"                       num_filters=num_filters_in,\\n\",\n        \"                       kernel_size=1,\\n\",\n        \"                       strides=strides,\\n\",\n        \"                       activation=activation,\\n\",\n        \"                       batch_normalization=batch_normalization,\\n\",\n        \"                       conv_first=False)\\n\",\n        \"      y = resnet_layer(inputs=y, num_filters=num_filters_in, conv_first=False)\\n\",\n        \"      y = resnet_layer(inputs=y,\\n\",\n        \"                       num_filters=num_filters_out,\\n\",\n        \"                       kernel_size=1,\\n\",\n        \"                       conv_first=False)\\n\",\n        \"      if res_block == 0:\\n\",\n        \"        # linear projection residual shortcut connection to match\\n\",\n        \"        # changed dims\\n\",\n        \"        x = resnet_layer(inputs=x,\\n\",\n        \"                         num_filters=num_filters_out,\\n\",\n        \"                         kernel_size=1,\\n\",\n        \"                         strides=strides,\\n\",\n        \"                         activation=None,\\n\",\n        \"                         batch_normalization=False)\\n\",\n        \"      x = layers.Add()([x, y])\\n\",\n        \"\\n\",\n        \"    num_filters_in = num_filters_out\\n\",\n        \"\\n\",\n        \"  # Add classifier on top.\\n\",\n        \"  # v2 has BN-ReLU before Pooling\\n\",\n        \"  x = layers.BatchNormalization()(x)\\n\",\n        \"  x = layers.Activation('relu')(x)\\n\",\n        \"  x = layers.AveragePooling2D(pool_size=8)(x)\\n\",\n        \"  y = layers.Flatten()(x)\\n\",\n        \"  outputs = layers.Dense(num_classes,\\n\",\n        \"                         
activation='softmax',\\n\",\n        \"                         kernel_initializer='he_normal')(y)\\n\",\n        \"\\n\",\n        \"  # Instantiate model.\\n\",\n        \"  model = tf.keras.Model(inputs=inputs, outputs=outputs)\\n\",\n        \"  return model\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"dAUaN-i9tHMY\"\n      },\n      \"source\": [\n        \"## Training\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"Hf5WFHYP8tT9\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"tf.reset_default_graph()\\n\",\n        \"tf.set_random_seed(SEED)\\n\",\n        \"\\n\",\n        \"data, info = get_input_pipeline(batch_size=batch_size,\\n\",\n        \"                                seed=SEED,\\n\",\n        \"                                repeat_validation=True,\\n\",\n        \"                                use_augmentation=True)\\n\",\n        \"\\n\",\n        \"model = resnet_v2(input_shape=info['input_shape'],\\n\",\n        \"                  depth=20,\\n\",\n        \"                  num_classes=info['num_classes'])\\n\",\n        \"\\n\",\n        \"loss = 'sparse_categorical_crossentropy'\\n\",\n        \"\\n\",\n        \"training_callbacks = [\\n\",\n        \"    kfac.keras.callbacks.ExponentialDecay(hyperparameter='learning_rate',\\n\",\n        \"                                          init_value=init_learning_rate,\\n\",\n        \"                                          final_value=final_learning_rate,\\n\",\n        \"                                          decay_rate=lr_decay_rate)\\n\",\n        \"]\\n\",\n        \"\\n\",\n        \"if optimizer_name == 'kfac':\\n\",\n        \"  opt = kfac.keras.optimizers.Kfac(learning_rate=init_learning_rate,\\n\",\n        \"                   
                damping=init_damping,\\n\",\n        \"                                   model=model,\\n\",\n        \"                                   loss=loss,\\n\",\n        \"                                   momentum=momentum,\\n\",\n        \"                                   seed=SEED)\\n\",\n        \"  training_callbacks.append(kfac.keras.callbacks.ExponentialDecay(\\n\",\n        \"      hyperparameter='damping',\\n\",\n        \"      init_value=init_damping,\\n\",\n        \"      final_value=final_damping,\\n\",\n        \"      decay_rate=damping_decay_rate))\\n\",\n        \"\\n\",\n        \"elif optimizer_name == 'adam':\\n\",\n        \"  opt = tf.keras.optimizers.Adam(learning_rate=init_learning_rate,\\n\",\n        \"                                 beta_1=momentum,\\n\",\n        \"                                 epsilon=init_epsilon)\\n\",\n        \"  training_callbacks.append(kfac.keras.callbacks.ExponentialDecay(\\n\",\n        \"      hyperparameter='epsilon',\\n\",\n        \"      init_value=init_epsilon,\\n\",\n        \"      final_value=final_epsilon,\\n\",\n        \"      decay_rate=epsilon_decay_rate))\\n\",\n        \"\\n\",\n        \"else:\\n\",\n        \"  raise ValueError('optimizer_name must be \\\"adam\\\" or \\\"kfac\\\"')\\n\",\n        \"\\n\",\n        \"model.compile(loss=loss, optimizer=opt, metrics=['acc'])\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"dD8b27hLy6lO\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"history = model.fit(x=data['train'],\\n\",\n        \"                    epochs=num_training_steps//steps_per_epoch,\\n\",\n        \"                    steps_per_epoch=steps_per_epoch,\\n\",\n        \"                    validation_data=data['validation'],\\n\",\n        \"                    validation_steps=val_steps,\\n\",\n        \" 
                   callbacks=training_callbacks)\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"collapsed_sections\": [\n        \"_DDaAex5Q7u-\"\n      ],\n      \"last_runtime\": {\n        \"build_target\": \"\",\n        \"kind\": \"local\"\n      },\n      \"name\": \"KFAC vs Adam on CIFAR10.ipynb\",\n      \"provenance\": [\n        {\n          \"file_id\": \"1pqtoYduODZyJKt4-kwVkt_KtNQCnaNDp\",\n          \"timestamp\": 1565229994386\n        }\n      ],\n      \"version\": \"0.3.2\"\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 2\",\n      \"name\": \"python2\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"_DDaAex5Q7u-\"\n      },\n      \"source\": [\n        \"##### Copyright 2019 The TensorFlow Authors.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"cellView\": \"both\",\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"W1dWWdNHQ9L0\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@title Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n\",\n        \"# you may not use this file except in compliance with the License.\\n\",\n        \"# You may obtain a copy of the License at\\n\",\n        \"#\\n\",\n        \"# https://www.apache.org/licenses/LICENSE-2.0\\n\",\n        \"#\\n\",\n        \"# Unless required by applicable law or agreed to in writing, software\\n\",\n        \"# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n\",\n        \"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n\",\n        \"# See the License for the specific language governing permissions and\\n\",\n        \"# limitations under the License.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"KDGkOqGA54FB\"\n      },\n      \"source\": [\n        \"# KFAC vs Adam on CIFAR10 on TPUs\\n\",\n        \"\\n\",\n        \"This notebook demonstrates how to write a custom training loop with TPU Strategy with KFAC. It can be run on a public TPU colab instance. The key differences between using KFAC with TPU Strategy and using TPU Strategy normally are that the model and optimizer must be created in your train step and KFAC does not work with model.fit (because of the first condition). 
We also use a batch_size of 1024 instead of 1000 to better utilize TPUs.\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tensorflow/kfac/blob/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"At1AvF75kmlr\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"!pip install kfac\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"LfGyhnaOsgYu\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import tensorflow as tf\\n\",\n        \"import math\\n\",\n        \"import kfac\\n\",\n        \"import os\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"DYWIY0C380ye\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"TRAINING_SIZE = 40000\\n\",\n        \"VALIDATION_SIZE = 10000\\n\",\n        \"TEST_SIZE = 10000\\n\",\n        \"SEED = 20190524\\n\",\n        \"\\n\",\n        \"num_training_steps = 7500\\n\",\n        \"# We use a batch size of 1024 instead 1000 because each TPU core should\\n\",\n        \"# (ideally) get a batch whose size is a multiple 128 (here we have 8 cores)\\n\",\n        \"batch_size = 1024\\n\",\n        \"layers = tf.keras.layers\\n\",\n        \"\\n\",\n        \"compute_steps_per_epoch = lambda x: int(math.floor(1. 
* x / batch_size))\\n\",\n        \"steps_per_epoch = compute_steps_per_epoch(TRAINING_SIZE)\\n\",\n        \"val_steps = compute_steps_per_epoch(VALIDATION_SIZE)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"GfeTgsbh5G4g\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"optimizer_name = 'kfac'  # 'kfac' or 'adam'\\n\",\n        \"\\n\",\n        \"# Best Hyperparameters from the Random Search\\n\",\n        \"if optimizer_name == 'kfac':\\n\",\n        \"  init_learning_rate = 0.22721400059936694\\n\",\n        \"  final_learning_rate = 1e-04\\n\",\n        \"  init_damping = 0.28872127217018184\\n\",\n        \"  final_damping = 1e-6\\n\",\n        \"  momentum = 1 - 0.018580394981260295\\n\",\n        \"  lr_decay_rate = 1 - 0.001090107322908028\\n\",\n        \"  damping_decay_rate = 1 - 0.0002870880729016523\\n\",\n        \"elif optimizer_name == 'adam':\\n\",\n        \"  init_learning_rate = 2.24266320779\\n\",\n        \"  final_learning_rate = 1e-4\\n\",\n        \"  init_epsilon = 0.183230038808\\n\",\n        \"  final_epsilon = 1e-8\\n\",\n        \"  momentum = 1 - 0.0296561513388\\n\",\n        \"  lr_decay_rate = 1 - 0.000610416031571\\n\",\n        \"  epsilon_decay_rate = 1 - 0.000212682338199\\n\",\n        \"else:\\n\",\n        \"  raise ValueError('Ensure optimizer_name is kfac or adam')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"v3vSki-usp9k\"\n      },\n      \"source\": [\n        \"## Input Pipeline\\n\",\n        \"\\n\",\n        \"The tensorflow_datasets CIFAR10 dataset (used in the GPU notebook) does not work with the public TPUs, because of the way the tf.data.Dataset is downloaded. 
So, this pipeline uses the tf.keras version, which downloads numpy arrays that we turn into a tf.data.Dataset.\\n\",\n        \"\\n\",\n        \"If this pipeline does not work, try using the pipeline in the GPU notebook instead.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"D2U3i5kgssy_\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def _parse_fn(image, label):\\n\",\n        \"  image = tf.cast(image, tf.float32)\\n\",\n        \"  label = tf.cast(tf.squeeze(label), tf.int32)\\n\",\n        \"  image = image / 127.5 - 1\\n\",\n        \"  return image, label\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def _augment_image(image, crop_amount, seed=None):\\n\",\n        \"  # Random Brightness, Contrast, Jpeg Quality, Hue, and Saturation did not\\n\",\n        \"  # seem to work well as augmentations for our training specifications\\n\",\n        \"  input_shape = image.shape.as_list()\\n\",\n        \"  cropped_size = [input_shape[0] - crop_amount,\\n\",\n        \"                  input_shape[1] - crop_amount,\\n\",\n        \"                  input_shape[2]]\\n\",\n        \"  flipped = tf.image.random_flip_left_right(image, seed)\\n\",\n        \"  cropped = tf.image.random_crop(flipped, cropped_size, seed)\\n\",\n        \"  return tf.image.pad_to_bounding_box(image=cropped,\\n\",\n        \"                                      offset_height=crop_amount // 2,\\n\",\n        \"                                      offset_width=crop_amount // 2,\\n\",\n        \"                                      target_height=input_shape[0],\\n\",\n        \"                                      target_width=input_shape[1])\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def _get_raw_data():\\n\",\n        \"  # We split the training data into training and validation ourselves for\\n\",\n        \" 
 # hyperparameter tuning.\\n\",\n        \"  train_and_val, test = tf.keras.datasets.cifar10.load_data()\\n\",\n        \"  train = (train_and_val[0][:TRAINING_SIZE], train_and_val[1][:TRAINING_SIZE])\\n\",\n        \"  val = (train_and_val[0][TRAINING_SIZE:], train_and_val[1][TRAINING_SIZE:])\\n\",\n        \"  info = {'input_shape':train_and_val[0].shape[1:], 'num_classes':10}\\n\",\n        \"  return (info,\\n\",\n        \"          tf.data.Dataset.from_tensor_slices(train),\\n\",\n        \"          tf.data.Dataset.from_tensor_slices(val),\\n\",\n        \"          tf.data.Dataset.from_tensor_slices(test))\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def get_input_pipeline(batch_size=None,\\n\",\n        \"                       use_augmentation=True,\\n\",\n        \"                       seed=None,\\n\",\n        \"                       crop_amount=6,\\n\",\n        \"                       drop_remainder=False,\\n\",\n        \"                       repeat_validation=True):\\n\",\n        \"  \\\"\\\"\\\"Creates CIFAR10 Data Pipeline.\\n\",\n        \"\\n\",\n        \"  Args:\\n\",\n        \"    batch_size (int): Batch size used for training.\\n\",\n        \"    use_augmentation (bool): If true, applies random horizontal flips and crops\\n\",\n        \"      then pads to images.\\n\",\n        \"    seed (int): Random seed used for augmentation operations.\\n\",\n        \"    crop_amount (int): Number of pixels to crop from the height and width of the\\n\",\n        \"      image. So, the cropped image will be [height - crop_amount, width -\\n\",\n        \"      crop_amount, channels] before it is padded to restore its original size.\\n\",\n        \"    drop_remainder (bool): Whether to drop the remainder of the batch. Needs to\\n\",\n        \"      be true to work on TPUs.\\n\",\n        \"    repeat_validation (bool): Whether to repeat the validation set. 
Test set is\\n\",\n        \"      never repeated.\\n\",\n        \"\\n\",\n        \"  Returns:\\n\",\n        \"    A tuple with an info dict (with input_shape (tuple) and number of classes\\n\",\n        \"    (int)) and data dict (train_data (tf.DatasetAdapter), validation_data,\\n\",\n        \"    (tf.DatasetAdapter) and test_data (tf.DatasetAdapter))\\n\",\n        \"  \\\"\\\"\\\"\\n\",\n        \"  info, train_data, val_data, test_data = _get_raw_data()\\n\",\n        \"\\n\",\n        \"  if not batch_size:\\n\",\n        \"    batch_size = max(TRAINING_SIZE, VALIDATION_SIZE, TEST_SIZE)\\n\",\n        \"\\n\",\n        \"  train_data = train_data.map(_parse_fn).shuffle(8192, seed=seed).repeat()\\n\",\n        \"  if use_augmentation:\\n\",\n        \"    train_data = train_data.map(\\n\",\n        \"        lambda x, y: (_augment_image(x, crop_amount, seed), y))\\n\",\n        \"  train_data = train_data.batch(\\n\",\n        \"      min(batch_size, TRAINING_SIZE), drop_remainder=drop_remainder)\\n\",\n        \"  train_data = train_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\\n\",\n        \"\\n\",\n        \"  val_data = val_data.map(_parse_fn)\\n\",\n        \"  if repeat_validation:\\n\",\n        \"    val_data = val_data.repeat()\\n\",\n        \"  val_data = val_data.batch(\\n\",\n        \"      min(batch_size, VALIDATION_SIZE), drop_remainder=drop_remainder)\\n\",\n        \"  val_data = val_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\\n\",\n        \"\\n\",\n        \"  # Don't repeat test data because it is only used once to evaluate at the end.\\n\",\n        \"  test_data = test_data.map(_parse_fn)\\n\",\n        \"  if repeat_validation:\\n\",\n        \"    test_data = test_data.repeat()\\n\",\n        \"  test_data = test_data.batch(\\n\",\n        \"      min(batch_size, TEST_SIZE), drop_remainder=drop_remainder)\\n\",\n        \"  test_data = test_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\\n\",\n  
      \"\\n\",\n        \"  data = {'train': train_data, 'validation': val_data, 'test': test_data}\\n\",\n        \"  return data, info\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"SLvlpsups2aR\"\n      },\n      \"source\": [\n        \"## Model - Resnet V2\\n\",\n        \"\\n\",\n        \"Based on https://keras.io/examples/cifar10_resnet/. The only difference is that tf.keras layer implementations are used, the model outputs logits instead of a probability distribution, and an input tensor is used instead of the input shape (since TPUs don't support placeholders).\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"Cch3Ld5Ds4i2\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def resnet_layer(inputs,\\n\",\n        \"                 num_filters=16,\\n\",\n        \"                 kernel_size=3,\\n\",\n        \"                 strides=1,\\n\",\n        \"                 activation='relu',\\n\",\n        \"                 batch_normalization=True,\\n\",\n        \"                 conv_first=True):\\n\",\n        \"  \\\"\\\"\\\"2D Convolution-Batch Normalization-Activation stack builder.\\n\",\n        \"\\n\",\n        \"  Based on https://keras.io/examples/cifar10_resnet/.\\n\",\n        \"\\n\",\n        \"  Args:\\n\",\n        \"    inputs (tensor): input tensor from input image or previous layer\\n\",\n        \"    num_filters (int): Conv2D number of filters\\n\",\n        \"    kernel_size (int): Conv2D square kernel dimensions\\n\",\n        \"    strides (int): Conv2D square stride dimensions\\n\",\n        \"    activation (string): activation name\\n\",\n        \"    batch_normalization (bool): whether to include batch normalization\\n\",\n        \"    conv_first (bool): conv-bn-activation 
(True) or bn-activation-conv (False)\\n\",\n        \"\\n\",\n        \"  Returns:\\n\",\n        \"    x (tensor): tensor as input to the next layer\\n\",\n        \"  \\\"\\\"\\\"\\n\",\n        \"  conv = layers.Conv2D(num_filters,\\n\",\n        \"                       kernel_size=kernel_size,\\n\",\n        \"                       strides=strides,\\n\",\n        \"                       padding='same',\\n\",\n        \"                       kernel_initializer='he_normal',\\n\",\n        \"                       kernel_regularizer=tf.keras.regularizers.l2(1e-4))\\n\",\n        \"\\n\",\n        \"  x = inputs\\n\",\n        \"  if conv_first:\\n\",\n        \"    x = conv(x)\\n\",\n        \"    if batch_normalization:\\n\",\n        \"      x = layers.BatchNormalization()(x)\\n\",\n        \"    if activation is not None:\\n\",\n        \"      x = layers.Activation(activation)(x)\\n\",\n        \"  else:\\n\",\n        \"    if batch_normalization:\\n\",\n        \"      x = layers.BatchNormalization()(x)\\n\",\n        \"    if activation is not None:\\n\",\n        \"      x = layers.Activation(activation)(x)\\n\",\n        \"    x = conv(x)\\n\",\n        \"  return x\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def resnet_v2(input_tensor, depth, num_classes=10):\\n\",\n        \"  \\\"\\\"\\\"ResNet Version 2 Model builder [b].\\n\",\n        \"\\n\",\n        \"    Based on https://keras.io/examples/cifar10_resnet/.\\n\",\n        \"\\n\",\n        \"    Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D or also known as\\n\",\n        \"    bottleneck layer\\n\",\n        \"    First shortcut connection per layer is 1 x 1 Conv2D.\\n\",\n        \"    Second and onwards shortcut connection is identity.\\n\",\n        \"    At the beginning of each stage, the feature map size is halved (downsampled)\\n\",\n        \"    by a convolutional layer with strides=2, while the number of filter maps is\\n\",\n        \"    doubled. 
Within each stage, the layers have the same number of filters and the\n",
        "    same filter map sizes.\n",
        "    Feature map sizes:\n",
        "    conv1  : 32x32,  16\n",
        "    stage 0: 32x32,  64\n",
        "    stage 1: 16x16, 128\n",
        "    stage 2:  8x8,  256\n",
        "\n",
        "    Args:\n",
        "      input_tensor (tensor): input image tensor\n",
        "      depth (int): number of core convolutional layers\n",
        "      num_classes (int): number of classes (CIFAR10 has 10)\n",
        "\n",
        "    Returns:\n",
        "      model (Model): Keras model instance\n",
        "    \"\"\"\n",
        "  if (depth - 2) % 9 != 0:\n",
        "    raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])')\n",
        "  # Start model definition.\n",
        "  num_filters_in = 16\n",
        "  num_res_blocks = int((depth - 2) / 9)\n",
        "\n",
        "  inputs = tf.keras.Input(tensor=input_tensor)\n",
        "  # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths\n",
        "  x = resnet_layer(inputs=inputs, num_filters=num_filters_in, conv_first=True)\n",
        "\n",
        "  # Instantiate the stack of residual units\n",
        "  for stage in range(3):\n",
        "    for res_block in range(num_res_blocks):\n",
        "      activation = 'relu'\n",
        "      batch_normalization = True\n",
        "      strides = 1\n",
        "      if stage == 0:\n",
        "        num_filters_out = num_filters_in * 4\n",
        "        if res_block == 0:  # first layer and first stage\n",
        "          activation = None\n",
        "          batch_normalization = False\n",
        "      else:\n",
        "        num_filters_out = num_filters_in * 2\n",
        "        if res_block == 0:  # first 
layer but not first stage\\n\",\n        \"          strides = 2  # downsample\\n\",\n        \"\\n\",\n        \"      # bottleneck residual unit\\n\",\n        \"      y = resnet_layer(inputs=x,\\n\",\n        \"                       num_filters=num_filters_in,\\n\",\n        \"                       kernel_size=1,\\n\",\n        \"                       strides=strides,\\n\",\n        \"                       activation=activation,\\n\",\n        \"                       batch_normalization=batch_normalization,\\n\",\n        \"                       conv_first=False)\\n\",\n        \"      y = resnet_layer(inputs=y, num_filters=num_filters_in, conv_first=False)\\n\",\n        \"      y = resnet_layer(inputs=y,\\n\",\n        \"                       num_filters=num_filters_out,\\n\",\n        \"                       kernel_size=1,\\n\",\n        \"                       conv_first=False)\\n\",\n        \"      if res_block == 0:\\n\",\n        \"        # linear projection residual shortcut connection to match\\n\",\n        \"        # changed dims\\n\",\n        \"        x = resnet_layer(inputs=x,\\n\",\n        \"                         num_filters=num_filters_out,\\n\",\n        \"                         kernel_size=1,\\n\",\n        \"                         strides=strides,\\n\",\n        \"                         activation=None,\\n\",\n        \"                         batch_normalization=False)\\n\",\n        \"      x = layers.Add()([x, y])\\n\",\n        \"\\n\",\n        \"    num_filters_in = num_filters_out\\n\",\n        \"\\n\",\n        \"  # Add classifier on top.\\n\",\n        \"  # v2 has BN-ReLU before Pooling\\n\",\n        \"  x = layers.BatchNormalization()(x)\\n\",\n        \"  x = layers.Activation('relu')(x)\\n\",\n        \"  x = layers.AveragePooling2D(pool_size=8)(x)\\n\",\n        \"  y = layers.Flatten()(x)\\n\",\n        \"  outputs = layers.Dense(num_classes,\\n\",\n        \"                         
activation='softmax',\\n\",\n        \"                         kernel_initializer='he_normal')(y)\\n\",\n        \"\\n\",\n        \"  # Instantiate model.\\n\",\n        \"  model = tf.keras.Model(inputs=inputs, outputs=outputs)\\n\",\n        \"  return model\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"hpSh8fWKiWO7\"\n      },\n      \"source\": [\n        \"## TPU Set Up\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"c_JSG9X7iaxu\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def get_tpu_address():\\n\",\n        \"  if 'TPU_NAME' in os.environ and 'COLAB_TPU_ADDR' in os.environ:  # public colab\\n\",\n        \"    assert os.environ['COLAB_GPU'] == '0'\\n\",\n        \"    TPU_ADDRESS = os.environ['TPU_NAME']\\n\",\n        \"    from google.colab import auth\\n\",\n        \"    auth.authenticate_user()\\n\",\n        \"    print('Running on public colab https://colab.research.google.com')\\n\",\n        \"  elif 'TPU_NAME' in os.environ and not 'COLAB_TPU_ADDR' in os.environ:  # Cloud TPU\\n\",\n        \"    TPU_ADDRESS = os.environ['TPU_NAME']\\n\",\n        \"    print('Running on Cloud TPU')\\n\",\n        \"  else:\\n\",\n        \"    raise ValueError('Unknown environment')\\n\",\n        \"  return TPU_ADDRESS\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"5LaK9He6Yw4B\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(\\n\",\n        \"    tpu=get_tpu_address())\\n\",\n        \"tf.tpu.experimental.initialize_tpu_system(cluster_resolver)\\n\",\n        \"tpu_strategy = 
tf.distribute.experimental.TPUStrategy(cluster_resolver)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"dAUaN-i9tHMY\"\n      },\n      \"source\": [\n        \"## Training\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"ClkFCVbDp0HK\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def get_optimizer(model, loss, global_step):\\n\",\n        \"  decayed_learning_rate = tf.train.exponential_decay(init_learning_rate,\\n\",\n        \"                                                     global_step=global_step,\\n\",\n        \"                                                     decay_rate=lr_decay_rate,\\n\",\n        \"                                                     decay_steps=1)\\n\",\n        \"  learning_rate = tf.maximum(decayed_learning_rate, final_learning_rate)\\n\",\n        \"\\n\",\n        \"  if optimizer_name == 'kfac':\\n\",\n        \"    decayed_damping = tf.train.exponential_decay(init_damping,\\n\",\n        \"                                                 global_step=global_step,\\n\",\n        \"                                                 decay_rate=damping_decay_rate,\\n\",\n        \"                                                 decay_steps=1)\\n\",\n        \"    damping = tf.maximum(decayed_damping, final_damping)\\n\",\n        \"    # We cannot use the Keras version because Keras optimizers do not support\\n\",\n        \"    # a global_step argument for minimize. 
Instead, we use the Keras automated\n",
        "    # layer collection functionality to get our layer collection.\n",
        "    lc = kfac.keras.utils.get_layer_collection(\n",
        "        model=model, loss=loss, seed=SEED)\n",
        "    optimizer = kfac.PeriodicInvCovUpdateKfacOpt(\n",
        "        learning_rate=learning_rate,\n",
        "        damping=damping,\n",
        "        momentum=momentum,\n",
        "        layer_collection=lc,\n",
        "        # Replica round robin places each inverse operation on a different \n",
        "        # replica (TPU core) so that each inverse is computed on one replica\n",
        "        # then the replicas are synced.\n",
        "        placement_strategy='replica_round_robin')\n",
        "\n",
        "  elif optimizer_name == 'adam':\n",
        "    decayed_epsilon = tf.train.exponential_decay(init_epsilon,\n",
        "                                                 global_step=global_step,\n",
        "                                                 decay_rate=epsilon_decay_rate,\n",
        "                                                 decay_steps=1)\n",
        "    epsilon = tf.maximum(decayed_epsilon, final_epsilon)\n",
        "    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,\n",
        "                                       beta1=momentum,\n",
        "                                       epsilon=epsilon)\n",
        "  return optimizer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "QuFSOAd-irxw"
      },
      "outputs": [],
      "source": [
        "def train_step_fn(info, loss_metric, accuracy_metric):\n",
        "  # We create the model in the train step, but we want to return a reference to\n",
    
    \"  # it so it can be used for validation. We return a reference to the model list\\n\",\n        \"  # which will be populated after the train_step is run.\\n\",\n        \"  model_list = []\\n\",\n        \"  def train_step(inputs):\\n\",\n        \"    # Need this for layer collection to work correctly. Also, by setting this\\n\",\n        \"    # to 1, batchnorm statistics are computed in this pass.\\n\",\n        \"    tf.keras.backend.set_learning_phase(1)\\n\",\n        \"\\n\",\n        \"    img, labels = inputs\\n\",\n        \"\\n\",\n        \"    # The model needs to be created in the train step for KFAC's layer\\n\",\n        \"    # collection. TPU Strategy autographs this function, so if the model is\\n\",\n        \"    # constructed outside the train step, KFAC's layer collection will capture\\n\",\n        \"    # the wrong input/output tensors.\\n\",\n        \"    # Since TPUs do not support placeholders, we must construct our model\\n\",\n        \"    # directly with the input tensor.\\n\",\n        \"    model = resnet_v2(input_tensor=img,\\n\",\n        \"                      depth=20,\\n\",\n        \"                      num_classes=info['num_classes'])\\n\",\n        \"    model_list.append(model)\\n\",\n        \"\\n\",\n        \"    # Since we constructed our model with the input tensor, the model.output\\n\",\n        \"    # is equivalent to model(img). 
In a non TPU custom training loop, you can\\n\",\n        \"    # use model(img) instead.\\n\",\n        \"    logits = model.output\\n\",\n        \"    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(\\n\",\n        \"        labels=labels, logits=logits)\\n\",\n        \"    regularization_loss = tf.reduce_sum(model.losses)\\n\",\n        \"    cross_entropy_loss = tf.reduce_mean(cross_entropy)\\n\",\n        \"    # When using Distribution Strategy with KFAC, you must NOT scale the loss.\\n\",\n        \"    loss = regularization_loss + cross_entropy_loss\\n\",\n        \"\\n\",\n        \"    update_loss = loss_metric.update_state(loss)\\n\",\n        \"    update_accuracy = accuracy_metric.update_state(y_true=labels, y_pred=logits)\\n\",\n        \"\\n\",\n        \"    global_step = tf.train.get_or_create_global_step()\\n\",\n        \"\\n\",\n        \"    optimizer = get_optimizer(model=model,\\n\",\n        \"                              loss='sparse_categorical_crossentropy',\\n\",\n        \"                              global_step=global_step)\\n\",\n        \"\\n\",\n        \"    train_op = optimizer.minimize(loss,\\n\",\n        \"                                  var_list=model.trainable_weights,\\n\",\n        \"                                  global_step=global_step)\\n\",\n        \"\\n\",\n        \"    # Control dependencies ensures updates are run before the loss is returned\\n\",\n        \"    with tf.control_dependencies([train_op, update_loss, update_accuracy]):\\n\",\n        \"      return tf.identity(loss)\\n\",\n        \"\\n\",\n        \"  return train_step, model_list\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"_FfEPQAvrTA1\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def eval_step_fn(model, loss_metric, accuracy_metric):\\n\",\n        \"  
\\\"\\\"\\\"For validation or test.\\\"\\\"\\\"\\n\",\n        \"\\n\",\n        \"  def eval_step(inputs):\\n\",\n        \"    tf.keras.backend.set_learning_phase(0)\\n\",\n        \"\\n\",\n        \"    img, labels = inputs\\n\",\n        \"    logits = model(img, training=False)\\n\",\n        \"    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(\\n\",\n        \"        labels=labels, logits=logits)\\n\",\n        \"    regularization_loss = tf.reduce_sum(model.losses)\\n\",\n        \"    cross_entropy_loss = tf.reduce_mean(cross_entropy)\\n\",\n        \"    loss = regularization_loss + cross_entropy_loss\\n\",\n        \"\\n\",\n        \"    update_loss = loss_metric.update_state(loss)\\n\",\n        \"    update_accuracy = accuracy_metric.update_state(y_true=labels, y_pred=logits)\\n\",\n        \"\\n\",\n        \"    with tf.control_dependencies([update_loss, update_accuracy]):\\n\",\n        \"      return tf.identity(loss)\\n\",\n        \"\\n\",\n        \"  return eval_step\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"Hf5WFHYP8tT9\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"tf.reset_default_graph()\\n\",\n        \"\\n\",\n        \"with tpu_strategy.scope():\\n\",\n        \"  data, info = get_input_pipeline(batch_size=batch_size,\\n\",\n        \"                                  seed=SEED,\\n\",\n        \"                                  drop_remainder=True,\\n\",\n        \"                                  repeat_validation=False)\\n\",\n        \"\\n\",\n        \"  train_iterator = tpu_strategy.make_dataset_iterator(data['train'])\\n\",\n        \"  val_iterator = tpu_strategy.make_dataset_iterator(data['validation'])\\n\",\n        \"\\n\",\n        \"  train_loss_metric = tf.keras.metrics.Mean(\\n\",\n        \"      'training_loss', 
dtype=tf.float32)\\n\",\n        \"  train_accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(\\n\",\n        \"      'training_accuracy', dtype=tf.float32)\\n\",\n        \"  val_loss_metric = tf.keras.metrics.Mean(\\n\",\n        \"      'val_loss', dtype=tf.float32)\\n\",\n        \"  val_accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(\\n\",\n        \"      'val_accuracy', dtype=tf.float32)\\n\",\n        \"\\n\",\n        \"  train_step, model_list = train_step_fn(\\n\",\n        \"      info, train_loss_metric, train_accuracy_metric)\\n\",\n        \"  # experimental_local_results gives us a list of the loss values from each\\n\",\n        \"  # replica. Since we're tracking loss via the Keras Metric, we don't need to\\n\",\n        \"  # worry about reporting (or reducing) this value. If we were to record this\\n\",\n        \"  # value, we should do a mean across replicas since each replica will return an\\n\",\n        \"  # unscaled loss and each replica has the same batch size.\\n\",\n        \"  train_step_op = tpu_strategy.experimental_local_results(\\n\",\n        \"      tpu_strategy.experimental_run(train_step, train_iterator))\\n\",\n        \"\\n\",\n        \"  model = model_list[0]  # There will only be one model in the list.\\n\",\n        \"  val_step = eval_step_fn(model, val_loss_metric, val_accuracy_metric)\\n\",\n        \"  val_step_op = tpu_strategy.experimental_local_results(\\n\",\n        \"      tpu_strategy.experimental_run(val_step, val_iterator))\\n\",\n        \"\\n\",\n        \"  all_variables = (\\n\",\n        \"      tf.global_variables() +\\n\",\n        \"      train_loss_metric.variables + train_accuracy_metric.variables +\\n\",\n        \"      val_loss_metric.variables + val_accuracy_metric.variables\\n\",\n        \"  )\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 0,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n       
 \"id\": \"D2w08M2dtwyL\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Without this config, TensorFlow will attempt to place two connected ops on\\n\",\n        \"# different devices, which will cause an InvalidArgumentError.\\n\",\n        \"config = tf.ConfigProto()\\n\",\n        \"config.allow_soft_placement = True\\n\",\n        \"cluster_spec = cluster_resolver.cluster_spec()\\n\",\n        \"if cluster_spec:\\n\",\n        \"  config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())\\n\",\n        \"\\n\",\n        \"with tf.Session(cluster_resolver.master(), config=config) as session:\\n\",\n        \"  session.run([v.initializer for v in all_variables])\\n\",\n        \"  session.run(train_iterator.initializer)\\n\",\n        \"  print('Starting training...')\\n\",\n        \"  for step in range(num_training_steps):\\n\",\n        \"    session.run(train_step_op)\\n\",\n        \"\\n\",\n        \"    if step % steps_per_epoch == 0:\\n\",\n        \"      session.run(val_iterator.initializer)\\n\",\n        \"      for _ in range(val_steps):\\n\",\n        \"        session.run(val_step_op)\\n\",\n        \"\\n\",\n        \"      print('================ Step {} ================'.format(step))\\n\",\n        \"      # The printed train loss is the mean over the entire epoch.\\n\",\n        \"      print('Train Loss {}'.format(session.run(train_loss_metric.result())))\\n\",\n        \"      print('Train Accuracy {}'.format(\\n\",\n        \"          session.run(train_accuracy_metric.result())))\\n\",\n        \"      print('Val Loss {}'.format(session.run(val_loss_metric.result())))\\n\",\n        \"      print('Val Accuracy {}'.format(\\n\",\n        \"          session.run(val_accuracy_metric.result())))\\n\",\n        \"      train_loss_metric.reset_states()\\n\",\n        \"      train_accuracy_metric.reset_states()\\n\",\n        \"      val_loss_metric.reset_states()\\n\",\n        \"      
val_accuracy_metric.reset_states()\\n\",\n        \"\\n\",\n        \"  print('Done training')\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"collapsed_sections\": [\n        \"_DDaAex5Q7u-\",\n        \"v3vSki-usp9k\",\n        \"SLvlpsups2aR\"\n      ],\n      \"last_runtime\": {\n        \"build_target\": \"\",\n        \"kind\": \"local\"\n      },\n      \"name\": \"KFAC vs Adam on CIFAR10 - TPU.ipynb\",\n      \"provenance\": [\n        {\n          \"file_id\": \"1GOgzfQLpg5aoq_uajqcqLqFTY0ohduEr\",\n          \"timestamp\": 1565229974969\n        },\n        {\n          \"file_id\": \"1pqtoYduODZyJKt4-kwVkt_KtNQCnaNDp\",\n          \"timestamp\": 1565044838251\n        }\n      ],\n      \"version\": \"0.3.2\"\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 2\",\n      \"name\": \"python2\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "kfac/examples/mnist.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utilities for loading MNIST into TensorFlow.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\n__all__ = [\n    'load_mnist_as_tensors',\n    'load_mnist_as_dataset',\n    'load_mnist_as_iterator',\n]\n\n\ndef load_mnist_as_tensors(flatten_images=True, dtype=tf.float32):\n  \"\"\"Loads MNIST as Tensors.\n\n  Args:\n    flatten_images: bool. 
If True, [28, 28, 1]-shaped images are flattened into\n      [784]-shaped vectors.\n    dtype: The TF dtype to return the images as.\n\n  Returns:\n    images, labels, num_examples\n  \"\"\"\n\n#   mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets(\n#       '/tmp/mnist', reshape=flatten_images)\n#   num_examples = len(mnist_data.train.labels)\n#   images = mnist_data.train.images\n#   labels = mnist_data.train.labels\n#\n#   images = tf.constant(np.asarray(images, dtype=np.float32))\n#   labels = tf.constant(np.asarray(labels, dtype=np.int64))\n#\n#   return images, labels, num_examples\n\n  (images, labels), _ = tf.keras.datasets.mnist.load_data()\n  num_examples = images.shape[0]\n\n  if flatten_images:\n    images = images.reshape(images.shape[0], 28**2)\n  else:\n    images = images.reshape(images.shape[0], 28, 28, 1)\n\n  images = images.astype('float64')\n  labels = labels.astype('int32')\n\n  images /= 255.\n\n  images = tf.constant(images, dtype=dtype)\n  labels = tf.constant(labels)\n\n  return images, labels, num_examples\n\n\ndef load_mnist_as_dataset(flatten_images=True):\n  \"\"\"Loads MNIST as a Dataset object.\n\n  Args:\n    flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into\n      [784]-shaped vectors.\n\n  Returns:\n    dataset, num_examples, where dataset is a Dataset object containing the\n    whole MNIST training dataset and num_examples is the number of examples\n    in the MNIST dataset (should be 60000).\n  \"\"\"\n  images, labels, num_examples = load_mnist_as_tensors(\n      flatten_images=flatten_images)\n  dataset = tf.data.Dataset.from_tensor_slices((images, labels))\n  return dataset, num_examples\n\n\ndef load_mnist_as_iterator(num_epochs, batch_size,\n                           use_fake_data=False,\n                           flatten_images=True):\n  \"\"\"Loads MNIST dataset as an iterator Tensor.\n\n  Args:\n    num_epochs: int. Number of passes to make over the dataset.\n    batch_size: int. 
Number of examples per minibatch.\n    use_fake_data: bool. If True, generate a synthetic dataset rather than\n      reading MNIST in.\n    flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into\n      [784]-shaped vectors.\n\n  Returns:\n    examples: Tensor of shape [batch_size, 784] if 'flatten_images' is\n      True, else [batch_size, 28, 28, 1]. Each row is one example.\n      Values in [0, 1].\n    labels: Tensor of shape [batch_size]. Indices of integer corresponding to\n      each example. Values in {0...9}.\n  \"\"\"\n\n  if use_fake_data:\n    rng = np.random.RandomState(42)\n    num_examples = batch_size * 4\n    images = rng.rand(num_examples, 28 * 28)\n    if not flatten_images:\n      images = np.reshape(images, [num_examples, 28, 28, 1])\n    labels = rng.randint(10, size=num_examples)\n    dataset = tf.data.Dataset.from_tensor_slices((np.asarray(\n        images, dtype=np.float32), np.asarray(labels, dtype=np.int64)))\n  else:\n    dataset, num_examples = load_mnist_as_dataset(flatten_images=flatten_images)\n\n  dataset = (dataset.shuffle(num_examples).repeat(num_epochs)\n             .batch(batch_size).prefetch(5))\n  return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()\n"
  },
  {
    "path": "kfac/examples/rnn_mnist.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"RNN trained to do sequential MNIST classification using K-FAC.\n\nThis demonstrates the use of the RNN approximations from the paper\n\"Kronecker-factored Curvature Approximations for Recurrent Neural Networks\".\n\nThe setup here is similar to the autoencoder example.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\n# Dependency imports\nfrom absl import flags\nimport kfac\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.examples import mnist\nfrom kfac.python.ops.kfac_utils import data_reader\nfrom kfac.python.ops.kfac_utils import data_reader_alt\n\n\n# We need this for now since linear layers without biases don't work with\n# automatic scanning at the moment\n_INCLUDE_INPUT_BIAS = True\n\n\nflags.DEFINE_string('kfac_approx', 'kron_indep',\n                    'The type of approximation to use for the recurrent '\n                    'layers. 
\"kron_indep\" is the one which assumes '\n                    'independence across time, \"kron_series_1\" is \"Option 1\" '\n                    'from the paper, and \"kron_series_2\" is \"Option 2\".')\n\nflags.DEFINE_integer('inverse_update_period', 5,\n                     '# of steps between computing inverse of Fisher factor '\n                     'matrices.')\nflags.DEFINE_integer('cov_update_period', 1,\n                     '# of steps between computing covariance matrices.')\nflags.DEFINE_integer('damping_adaptation_interval', 5,\n                     '# of steps between updating the damping parameter.')\n\nflags.DEFINE_float('learning_rate', 3e-4,\n                   'Learning rate to use when adaptation=\"off\".')\nflags.DEFINE_float('momentum', 0.9,\n                   'Momentum decay value to use when '\n                   'lrmu_adaptation=\"off\" or \"only_lr\".')\n\nflags.DEFINE_boolean('use_batch_size_schedule', True,\n                     'If True then we use the growing mini-batch schedule from '\n                     'the original K-FAC paper.')\nflags.DEFINE_integer('batch_size', 1024,\n                     'The size of the mini-batches to use if not using the '\n                     'schedule.')\n\nflags.DEFINE_string('lrmu_adaptation', 'on',\n                    'If set to \"on\" then we use the quadratic model '\n                    'based learning-rate and momentum adaptation method from '\n                    'the original paper. Note that this only works well in '\n                    'practice when use_batch_size_schedule=True. 
Can also '\n                    'be set to \"off\" and \"only_lr\", which turns '\n                    'it off, or uses a version where the momentum parameter '\n                    'is fixed (resp.).')\n\n\nflags.DEFINE_boolean('use_alt_data_reader', True,\n                     'If True we use the alternative data reader for MNIST '\n                     'that is faster for small datasets.')\n\nflags.DEFINE_integer('num_hidden', 128, 'Hidden state dimension of the RNN.')\n\nflags.DEFINE_boolean('use_auto_registration', False,\n                     'Whether to use the automatic registration feature.')\n\nflags.DEFINE_string('device', '/gpu:0',\n                    'The device to run the major ops on.')\n\n\nFLAGS = flags.FLAGS\n\n\ndef make_train_op(batch_size,\n                  batch_loss,\n                  layer_collection,\n                  loss_fn,\n                  cached_reader):\n  \"\"\"Constructs optimizer and train op.\n\n  Args:\n    batch_size: Tensor of shape (), Size of the training batch.\n    batch_loss: Tensor of shape (), Loss with respect to minibatch to be\n      minimized.\n    layer_collection: LayerCollection or None. 
Registry for model parameters.\n      Required when using a K-FAC optimizer.\n    loss_fn: Function which takes as input training data and returns loss.\n    cached_reader: `data_reader.CachedReader` instance.\n\n  Returns:\n    train_op: Op that can be used to update model parameters.\n    optimizer: Optimizer used to produce train_op.\n\n  Raises:\n    ValueError: If layer_collection is None when K-FAC is selected as an\n      optimization method.\n  \"\"\"\n  global_step = tf.train.get_or_create_global_step()\n\n  if layer_collection is None:\n    raise ValueError('layer_collection must be defined to use K-FAC.')\n\n  if FLAGS.lrmu_adaptation == 'on':\n    learning_rate = None\n    momentum = None\n    momentum_type = 'qmodel'\n  elif FLAGS.lrmu_adaptation == 'only_lr':\n    learning_rate = None\n    momentum = FLAGS.momentum\n    momentum_type = 'qmodel_fixedmu'\n  elif FLAGS.lrmu_adaptation == 'off':\n    learning_rate = FLAGS.learning_rate\n    momentum = FLAGS.momentum\n    # momentum_type = 'regular'\n    momentum_type = 'adam'\n\n  optimizer = kfac.PeriodicInvCovUpdateKfacOpt(\n      invert_every=FLAGS.inverse_update_period,\n      cov_update_every=FLAGS.cov_update_period,\n      learning_rate=learning_rate,\n      damping=150.,  # When using damping adaptation it is advisable to start\n                     # with a high value. This value is probably far too high\n                     # to use for most neural nets if you aren't using damping\n                     # adaptation. 
(Although it always depends on the scale of\n                     # the loss.)\n      cov_ema_decay=0.95,\n      momentum=momentum,\n      momentum_type=momentum_type,\n      layer_collection=layer_collection,\n      batch_size=batch_size,\n      num_burnin_steps=5,\n      adapt_damping=True,\n      is_chief=True,\n      prev_train_batch=cached_reader.cached_batch,\n      loss=batch_loss,\n      loss_fn=loss_fn,\n      damping_adaptation_decay=0.95,\n      damping_adaptation_interval=FLAGS.damping_adaptation_interval,\n      min_damping=1e-5\n      )\n  return optimizer.minimize(batch_loss, global_step=global_step), optimizer\n\n\ndef eval_model(x, num_classes, layer_collection=None):\n  \"\"\"Evaluate the model given the data and possibly register it.\"\"\"\n\n  num_hidden = FLAGS.num_hidden\n  num_timesteps = x.shape[1]\n  num_input = x.shape[2]\n\n  # Strip off the annoying last dimension of size 1 (added for convenient use\n  # with conv nets).\n  x = x[..., 0]\n\n  # Unstack to get a list of 'num_timesteps' tensors of\n  # shape (batch_size, num_input)\n  x_unstack = tf.unstack(x, num_timesteps, 1)\n\n  # We need to do this manually without cells since we need to get access\n  # to the pre-activations (i.e. 
the output of the \"linear layers\").\n  w_in = tf.get_variable('w_in', shape=[num_input, num_hidden])\n  if _INCLUDE_INPUT_BIAS:\n    b_in = tf.get_variable('b_in', shape=[num_hidden])\n\n  w_rec = tf.get_variable('w_rec', shape=[num_hidden, num_hidden])\n  b_rec = tf.get_variable('b_rec', shape=[num_hidden])\n\n  a = tf.zeros([tf.shape(x_unstack[0])[0], num_hidden], dtype=tf.float32)\n\n  # Here 'a' are the activations, 's' the pre-activations\n  a_list = []\n  s_in_list = []\n  s_rec_list = []\n  s_list = []\n\n  for input_ in x_unstack:\n\n    a_list.append(a)\n\n    s_in = tf.matmul(input_, w_in)\n    if _INCLUDE_INPUT_BIAS:\n      s_in += b_in\n    s_rec = tf.matmul(a, w_rec) + b_rec\n    # s_rec = b_rec + tf.matmul(a, w_rec)  # this breaks the graph scanner\n    s = s_in + s_rec\n\n    s_in_list.append(s_in)\n    s_rec_list.append(s_rec)\n    s_list.append(s)\n\n    a = tf.tanh(s)\n\n  final_rnn_output = a\n\n  # NOTE: we can uncomment the lines below without changing how the algorithm\n  # behaves.  This is because the derivative of the loss w.r.t. to s is the\n  # the same as it is for both s_in and s_rec.  
This can be seen easily from\n  # the chain rule.\n  #\n  # s_rec_list = s_list\n  # s_in_list = s_list\n\n  if _INCLUDE_INPUT_BIAS:\n    pin = (w_in, b_in)\n  else:\n    pin = w_in\n\n  if layer_collection:\n    layer_collection.register_fully_connected_multi(pin, x_unstack,\n                                                    s_in_list,\n                                                    approx=FLAGS.kfac_approx)\n\n    layer_collection.register_fully_connected_multi((w_rec, b_rec), a_list,\n                                                    s_rec_list,\n                                                    approx=FLAGS.kfac_approx)\n\n  # Output parameters (need this no matter how we construct the RNN):\n  w_out = tf.get_variable('w_out', shape=[num_hidden, num_classes])\n  b_out = tf.get_variable('b_out', shape=[num_classes])\n\n  logits = tf.matmul(final_rnn_output, w_out) + b_out\n\n  if layer_collection:\n    layer_collection.register_fully_connected((w_out, b_out), final_rnn_output,\n                                              logits)\n\n  return logits\n\n\ndef compute_loss(inputs, labels, num_classes, layer_collection=None):\n  \"\"\"Compute loss value.\"\"\"\n\n  with tf.variable_scope('model', reuse=tf.AUTO_REUSE):\n    if FLAGS.use_auto_registration:\n      logits = eval_model(inputs, num_classes)\n    else:\n      logits = eval_model(inputs, num_classes,\n                          layer_collection=layer_collection)\n\n  losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,\n                                                          labels=labels)\n  loss = tf.reduce_mean(losses)\n\n  if layer_collection is not None:\n    layer_collection.register_softmax_cross_entropy_loss(logits)\n    if FLAGS.use_auto_registration:\n      layer_collection.auto_register_layers()\n\n  return loss\n\n\ndef load_mnist():\n  \"\"\"Creates MNIST dataset and wraps it inside cached data reader.\n\n  Returns:\n    cached_reader: `data_reader.CachedReader` 
instance which wraps MNIST\n      dataset.\n    num_examples: int. The number of training examples.\n  \"\"\"\n  # Wrap the data set into cached_reader which provides variable sized training\n  # and caches the read train batch.\n\n  if not FLAGS.use_alt_data_reader:\n    # Version 1 using data_reader.py (slow!)\n    dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=False)\n    if FLAGS.use_batch_size_schedule:\n      max_batch_size = num_examples\n    else:\n      max_batch_size = FLAGS.batch_size\n\n    # Shuffle before repeat is correct unless you want repeat cases in the\n    # same batch.\n    dataset = (dataset.shuffle(num_examples).repeat()\n               .batch(max_batch_size).prefetch(5))\n    dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()\n\n    # This version of CachedDataReader requires the dataset to be shuffled\n    return data_reader.CachedDataReader(dataset, max_batch_size), num_examples\n\n  else:\n    # Version 2 using data_reader_alt.py (faster)\n    images, labels, num_examples = mnist.load_mnist_as_tensors(\n        flatten_images=False)\n    dataset = (images, labels)\n\n    # This version of CachedDataReader requires the dataset to NOT be shuffled\n    return data_reader_alt.CachedDataReader(dataset, num_examples), num_examples\n\n\ndef main(_):\n  # Load dataset.\n  cached_reader, num_examples = load_mnist()\n  num_classes = 10\n\n  minibatch_maxsize_targetiter = 500\n  minibatch_maxsize = num_examples\n  minibatch_startsize = 1000\n\n  div = (float(minibatch_maxsize_targetiter-1)\n         / math.log(float(minibatch_maxsize)/minibatch_startsize, 2))\n  batch_size_schedule = [\n      min(int(2.**(float(k)/div) * minibatch_startsize), minibatch_maxsize)\n      for k in range(500)\n  ]\n\n  batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')\n\n  layer_collection = kfac.LayerCollection()\n\n  def loss_fn(minibatch, layer_collection=None):\n    return compute_loss(minibatch[0], 
minibatch[1], num_classes,\n                        layer_collection=layer_collection)\n\n  minibatch = cached_reader(batch_size)\n  batch_loss = loss_fn(minibatch, layer_collection=layer_collection)\n\n  # Make training op\n  with tf.device(FLAGS.device):\n    train_op, opt = make_train_op(\n        batch_size,\n        batch_loss,\n        layer_collection,\n        loss_fn=loss_fn,\n        cached_reader=cached_reader)\n\n  learning_rate = opt.learning_rate\n  momentum = opt.momentum\n  damping = opt.damping\n  rho = opt.rho\n  qmodel_change = opt.qmodel_change\n  global_step = tf.train.get_or_create_global_step()\n\n  # Without setting allow_soft_placement=True there will be problems when\n  # the optimizer tries to place certain ops like \"mod\" on the GPU (which isn't\n  # supported).\n  config = tf.ConfigProto(allow_soft_placement=True)\n\n  # Train model.\n  with tf.train.MonitoredTrainingSession(save_checkpoint_secs=30,\n                                         config=config) as sess:\n    while not sess.should_stop():\n      i = sess.run(global_step)\n\n      if FLAGS.use_batch_size_schedule:\n        batch_size_ = batch_size_schedule[min(i, len(batch_size_schedule) - 1)]\n      else:\n        batch_size_ = FLAGS.batch_size\n\n      _, batch_loss_ = sess.run([train_op, batch_loss],\n                                feed_dict={batch_size: batch_size_})\n\n      # We get these things in a separate sess.run() call because they are\n      # stored as variables in the optimizer. 
(So there is no computational cost\n      # to getting them, and if we don't get them after the previous call is\n      # over they might not be updated.)\n      (learning_rate_, momentum_, damping_, rho_,\n       qmodel_change_) = sess.run([learning_rate, momentum, damping, rho,\n                                   qmodel_change])\n\n      # Print training stats.\n      tf.logging.info(\n          'iteration: %d', i)\n      tf.logging.info(\n          'mini-batch size: %d | mini-batch loss = %f',\n          batch_size_, batch_loss_)\n      tf.logging.info(\n          'learning_rate = %f | momentum = %f',\n          learning_rate_, momentum_)\n      tf.logging.info(\n          'damping = %f | rho = %f | qmodel_change = %f',\n          damping_, rho_, qmodel_change_)\n      tf.logging.info('----')\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.app.run(main)\n\n\n"
  },
  {
    "path": "kfac/python/__init__.py",
    "content": "\n"
  },
  {
    "path": "kfac/python/keras/README.md",
    "content": "# K-FAC for Keras\n\n**K-FAC for Keras** is an implementation of K-FAC, an approximate second-order\noptimization method, in TensorFlow. You can read more about it in the paper\n[here][paper] and the GitHub docs [here][index].\n\n[index]: https://github.com/tensorflow/kfac/tree/master/docs/index.md\n[paper]: https://arxiv.org/abs/1503.05671\n\n## Why should I use K-FAC for Keras?\n\nIn addition to the reasons outlined on the GitHub docs, the Keras version\nhandles layer and loss registration automatically and works with Keras's\nconvenient training API. See the reference code [here][cifar10].\n\n[cifar10]: https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10.ipynb\n[cifar10tpu]: https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb\n\n## How do I use K-FAC for Keras?\n\nUsing this optimizer is almost the same as using any other Keras optimizer,\nexcept you must also pass the loss and model to the optimizer. The optimizer\nwill automatically register the model layers and loss so K-FAC can compute the\nfisher approximations.\n\n```python\nimport tensorflow.compat.v1 as tf\nimport kfac\n\n# Build Keras Model (can use functional or sequential)\nmodel = tf.keras.Model(...)\nloss = 'sparse_categorical_crossentropy' # or a tf.keras.losses.* instance\n\n# Construct Optimizer\noptimizer = kfac.keras.optimizers.Kfac(learning_rate=0.001,\n                                       damping=0.01,\n                                       model=model,\n                                       loss=loss)\n\n# Compile and Fit Model\nmodel.compile(optimizer=optimizer, loss=loss, ...)\nmodel.fit(...)\n```\n\nCheck out our CIFAR-10 CNN training [example][cifar10] and\n[TPU Strategy example][cifar10tpu] for more details.\n\nThis optimizer currently supports the following tf.keras.layers types: Conv2D,\nConv1D, Dense, BatchNormalization, LayerNormalization and Embedding. 
The\nfollowing tf.keras.losses are supported: sparse_categorical_crossentropy,\ncategorical_crossentropy, binary_crossentropy, and mean_squared_error. You may\nuse any architecture with these basic layers and losses, including multiple\nbranches and loss functions.\n\nTo use an unsupported layer or loss, you can register layers manually using\na LayerCollection object and pass that to the optimizer constructor. Examples\nof using LayerCollection are [here][layercollection].\n\n[layercollection]: https://github.com/tensorflow/kfac/tree/master/kfac/examples\n\n## How is K-FAC Different from Other Keras Optimizers?\n\n1.  When using your model as a callable (i.e. `output = model(input)`), `input`\n    must be a Keras layer. If it is a normal tensor, you can wrap it as follows:\n    `new_input = tf.keras.layers.Input(tensor=input)`. This is so Keras\n    registers the layer as an inbound_node during the call, allowing our layer\n    collection to register it correctly. By default, our automatic layer\n    collection will register only the latest use of the model.\n2.  Only a subset of the hyperparameters can be accessed and modified after\n    instantiation. These are: learning_rate, damping, momentum,\n    weight_decay_coeff, norm_constraint, and batch_size. These hyperparameters\n    will work the same as normal hyperparameters in native Keras optimizers and\n    can be used with tools like hyperparameter scheduler callbacks. You can see\n    exactly which hyperparameters are modifiable by checking the\n    `optimizer.mutable_hyperparameters` property. Note that damping cannot be\n    modified when using adaptive damping, and momentum/learning_rate cannot be\n    modified when using qmodel momentum. Also, if any of the hyperparameters are\n    `None` during instantiation, they will not be modifiable during training.\n3.  This optimizer is tested with TPUStrategy and MirroredStrategy. However,\n    you may not use a Strategy with model.fit for two reasons. 
First, we expect\n    an unscaled loss (i.e. it should NOT be scaled by 1.0 / global_batch_size).\n    Second, TPUStrategy will autograph the train step, so your model and\n    optimizer must both be created in the train step for KFAC to work. This is\n    not possible with model.fit. See our [CIFAR10 TPU][cifar10tpu] example for\n    details on how to do this.\n4.  This optimizer is fully compatible with tf.keras.models.save_model or\n    model.save(). To load the compiled model with the optimizer, you must use\n    our saving_utils.load_model method, which is identical to\n    tf.keras.models.load_model except it registers the model with the optimizer\n    after compiling the model and before loading the optimizer's weights.\n    Example:\n\n    ```python\n    import tensorflow as tf\n    import kfac\n\n    model = tf.keras.Model(...)\n    loss = tf.keras.losses.MSE()  # could be a serialized loss function\n    optimizer = kfac.keras.optimizers.Kfac(learning_rate=0.001,\n                                           damping=0.01,\n                                           model=model,\n                                           loss=loss)\n    model.compile(optimizer, loss)\n    model.fit(...)\n    model.save('saved_model.hdf5')  # or tf.keras.models.save_model(model)\n    ...\n    loaded_model = kfac.keras.saving_utils.load_model('saved_model.hdf5')\n    loaded_model.fit(...)\n    ```\n\n## EXPERIMENTAL - How can I use the adaptive damping/momentum/learning rate?\n\nThe original [KFAC paper][paper] outlines how the optimizer can automatically\nadjust the learning rate, momentum, and damping. 
You can use it as follows:\n\n```python\nimport tensorflow.compat.v1 as tf\nimport kfac\n\n# tf.data.Dataset dataset\ndataset = ...\ndataset = dataset.shuffle(...).repeat().batch(..., drop_remainder=True)\ntrain_batch = dataset.make_one_shot_iterator().get_next() # (x, y) tensors\n\nmodel = tf.keras.Model(...)\nloss = 'sparse_categorical_crossentropy'\n\n# Construct Optimizer\noptimizer = kfac.keras.optimizers.Kfac(damping=10.0,\n                                       adaptive=True,\n                                       model=model,\n                                       loss=loss,\n                                       train_batch=train_batch,\n                                       ...)\n\n# Compile and Fit Model\nmodel.compile(optimizer=optimizer, loss=loss, ...)\nmodel.fit(train_batch, ...)\n```\n\nIf your batch size is not fixed at the start of training (i.e. it has an ?\ndimension, such as when `drop_remainder=False`), you must pass the `batch_size`\nin the constructor. If you do not use `optimizer.minimize(...)`, you must\npass in the `loss_tensor`. If you use a custom loss function, you must pass in\nthe `loss_fn` in the constructor. Look at the documentation for the\nTensorFlow KFAC optimizer for details on how to customize this more.\n\nNote that this feature is experimental, so it is not recommended for standard\nuse cases. It works best when used with a high initial damping (10.0-100.0), and\nwith a large batch size. The [autoencoder example][ae_eg] shows using the\nadaptive damping and qmodel momentum successfully.\n\n[ae_eg]: https://github.com/tensorflow/kfac/blob/master/kfac/examples/autoencoder_mnist.py\n"
  },
  {
    "path": "kfac/python/keras/__init__.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"KFAC for Keras.\"\"\"\n\nfrom kfac.python.keras import callbacks\nfrom kfac.python.keras import optimizers\nfrom kfac.python.keras import saving_utils\nfrom kfac.python.keras import utils\n"
  },
  {
    "path": "kfac/python/keras/callbacks.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Hyperparameter Scheduling Callbacks for Keras K-FAC.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\nimport six\nimport tensorflow.compat.v1 as tf\n\n\n@six.add_metaclass(abc.ABCMeta)\nclass HyperparameterDecay(tf.keras.callbacks.Callback):\n  \"\"\"Base class for global_step/iterations-based optimizer decay callbacks.\"\"\"\n\n  def __init__(self, hyperparameter, num_delay_steps=0, verbose=0):\n    \"\"\"Construct a new HyperparameterDecay.\n\n    Args:\n      hyperparameter: String specifying the optimizer attribute to decay.\n      num_delay_steps: Integer specifying how many steps to wait before decaying\n        the attribute.\n      verbose: Integer. 
When > 0, the hyperparameter value is printed every\n        epoch.\n    \"\"\"\n\n    self._hyperparameter = hyperparameter\n    self._num_delay_steps = num_delay_steps\n    self.verbose = verbose\n\n  def on_train_begin(self, logs=None):\n    self._optimizer = self.model.optimizer\n    if not hasattr(self._optimizer, self._hyperparameter):\n      raise ValueError('Optimizer must have a \"{}\" attribute.'\n                       .format(self._hyperparameter))\n    if not hasattr(self._optimizer, 'iterations'):\n      raise ValueError('Optimizer must have a \"iterations\" attribute.')\n\n  def on_epoch_begin(self, epoch, logs=None):\n    if self.verbose > 0:\n      value = float(tf.keras.backend.get_value(getattr(self._optimizer,\n                                                       self._hyperparameter)))\n      print('\\nEpoch {:05}: Current {} is {}.'\n            .format(epoch + 1, self._hyperparameter, value))\n\n  def on_epoch_end(self, epoch, logs=None):\n    if logs is not None:\n      logs[self._hyperparameter] = tf.keras.backend.get_value(\n          getattr(self._optimizer, self._hyperparameter))\n\n  def _get_global_step(self):\n    return (tf.keras.backend.get_value(self._optimizer.iterations)\n            - self._num_delay_steps)\n\n\nclass PolynomialDecay(HyperparameterDecay):\n  \"\"\"Polynomial Optimizer Hyperparameter Schedule.\n\n  Based on https://www.tensorflow.org/api_docs/python/tf/train/polynomial_decay\n\n  The decay applies as follows for num_decay_steps steps when the global_step\n  (i.e. 
optimizer.iterations) exceeds the num_delay_steps.\n\n    step = global_step - num_delay_steps\n    decayed_value = (init_value - final_value) *\n                    (1 - step / num_decay_steps) ^ (power) + final_value\n  \"\"\"\n\n  def __init__(self,\n               hyperparameter,\n               init_value,\n               final_value,\n               power,\n               num_decay_steps,\n               **kwargs):\n    \"\"\"Construct a new PolynomialDecay Callback.\n\n    Args:\n      hyperparameter: String specifying the optimizer attribute to decay.\n      init_value: Float specifying initial value of the attribute.\n      final_value: Float specifying value of attribute at the end of the decay.\n      power: Float specifying power (exponent) of the polynomial decay.\n      num_decay_steps: Integer, number of steps to decay the attribute.\n      **kwargs: Keyword arguments for HyperparameterDecay. This includes\n        num_delay_steps and verbose.\n    \"\"\"\n    super(PolynomialDecay, self).__init__(hyperparameter, **kwargs)\n    self._init_value = init_value\n    self._final_value = final_value\n    self._power = power\n    self._num_decay_steps = num_decay_steps\n\n  def on_batch_begin(self, batch, logs=None):\n    step = self._get_global_step()\n    if step > 0 and step <= self._num_decay_steps:\n      decayed_value = ((self._init_value - self._final_value) *\n                       (1 - step / self._num_decay_steps) ** (self._power) +\n                       self._final_value)\n      setattr(self._optimizer, self._hyperparameter, decayed_value)\n\n\nclass ExponentialDecay(HyperparameterDecay):\n  \"\"\"Exponential Optimizer Hyperparameter Decay Schedule.\n\n  The decay applies as follows for num_decay_steps steps when the global_step\n  (i.e. optimizer.iterations) exceeds the num_delay_steps. 
If num_decay_steps\n  is not provided, it will keep decaying for the duration of training.\n\n  When a decay rate and num_decay_steps is provided:\n    step = min(global_step - num_delay_steps, num_decay_steps)\n    decayed_value = init_value * decay_rate^step\n\n  When a decay_rate and final_value are provided:\n    step = global_step - num_delay_steps\n    decayed_value = max(init_value * decay_rate^step, final_value)\n\n  When a final value and num_decay_steps is provided:\n    step = global_step - num_delay_steps\n    decayed_value = init_value *\n                   (final_value / init_value) ^ (step / num_decay_steps)\n  \"\"\"\n\n  def __init__(self,\n               hyperparameter,\n               init_value,\n               final_value=None,\n               decay_rate=None,\n               num_decay_steps=None,\n               **kwargs):\n    \"\"\"Construct a new ExponentialDecay Callback.\n\n    You must specify exactly two of final_value, decay_rate, and\n    num_decay_steps.\n\n\n    Args:\n      hyperparameter: String specifying the optimizer attribute to decay.\n      init_value: Float specifying initial value of the attribute.\n      final_value: Float specifying value of attribute at the end of the decay.\n      decay_rate: Float specifying the decay rate of the decay.\n      num_decay_steps: Integer, number of steps to decay the attribute.\n      **kwargs: Keyword arguments for HyperparameterDecay. 
This includes\n        num_delay_steps and verbose.\n    \"\"\"\n    super(ExponentialDecay, self).__init__(hyperparameter, **kwargs)\n    self._num_decay_steps = num_decay_steps\n\n    # In theory, we could support more different combinations of final_value,\n    # num_decay_steps, and decay_rate, but for the sake of clarity we will limit\n    # this callback to the below combinations.\n    if final_value and decay_rate and num_decay_steps:\n      raise ValueError('You must specify exactly two of final_value, decay_rate'\n                       ', and num_decay_steps.')\n    if final_value and decay_rate:\n      self._decay_func = lambda step: max(  # pylint: disable=g-long-lambda\n          (init_value * (decay_rate ** step)), final_value)\n    elif decay_rate and num_decay_steps:\n      self._decay_func = lambda step: (init_value * decay_rate ** step)\n    elif final_value and num_decay_steps:\n      self._decay_func = lambda step: (  # pylint: disable=g-long-lambda\n          init_value * (final_value / init_value) **\n          (float(step) / num_decay_steps))\n    else:\n      raise ValueError('You must specify exactly two of final_value, decay_rate'\n                       ', and num_decay_steps.')\n\n  def on_batch_begin(self, batch, logs=None):\n    global_step = self._get_global_step()\n    if (global_step > 0 and\n        (not self._num_decay_steps or global_step <= self._num_decay_steps)):\n      decayed_value = self._decay_func(global_step)\n      setattr(self._optimizer, self._hyperparameter, decayed_value)\n"
  },
  {
    "path": "kfac/python/keras/optimizers.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"KFAC Optimizer for Keras.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport inspect\nimport numbers\nimport re\nfrom absl import logging\nfrom tensorflow.python.keras import backend\nimport six\nimport tensorflow.compat.v1 as tf\n\n\nfrom kfac.python.keras import utils\nfrom kfac.python.ops import optimizer\nfrom kfac.python.ops.kfac_utils import periodic_inv_cov_update_kfac_opt\n\n# TODO(b/135110195): Support letting the user choose the TF KFAC optimizer.\n_KFAC_OPT_CLASS = periodic_inv_cov_update_kfac_opt.PeriodicInvCovUpdateKfacOpt\n\n# TODO(b/134945404): Change how default config args are retrieved.\ngetfullargspec = inspect.getfullargspec if six.PY3 else inspect.getargspec\n_KFAC_ARGS = getfullargspec(optimizer.KfacOptimizer.__init__)\n_PERIODIC_KFAC_ARGS = getfullargspec(_KFAC_OPT_CLASS.__init__)\n_DEFAULT_KWARGS = dict(zip(reversed(_KFAC_ARGS.args),\n                           reversed(_KFAC_ARGS.defaults)))\n_DEFAULT_KWARGS.update(zip(reversed(_PERIODIC_KFAC_ARGS.args),\n                           reversed(_PERIODIC_KFAC_ARGS.defaults)))\n\n_MUTABLE_HYPER_PARAMS = {'learning_rate',\n                         'momentum',\n                         'damping',\n        
                 'weight_decay_coeff',\n                         'norm_constraint'}\n\n\ndef _configure_kfac_kwargs_for_adaptive(kfac_kwargs, adaptive):\n  \"\"\"Checks and fills in some required kwargs to use an adaptive mode.\n\n  This will set up kfac_kwargs for adaptive, adapt_damping, and/or qmodel\n  momentum, if needed. It will not check for train_batch or batch_size, as that\n  check happens right before the minimize. It will set the following if not set\n  by the user:\n\n  If adaptive=True:\n    - adapt_damping=True, momentum=None, momentum_type='qmodel'\n    - The checks listed below.\n\n  If adapt_damping=True:\n    - use_passed_loss=True, and then it will get the loss_tensor from minimize.\n    - update_damping_immediately=True\n    - damping_adaptation_interval=5 if the user hasn't set this already.\n    - invert_every=5 if the user hasn't set this already.\n\n  If momentum_type='qmodel' or momentum_type='qmodel_fixedmu':\n    - Ensures learning rate and momentum are None.\n\n  Args:\n   kfac_kwargs: dict of keyword arguments to be passed to\n     PeriodicInvCovUpdateKfacOpt.\n   adaptive: bool indicating the optimizer is in adaptive mode.\n  \"\"\"\n  if adaptive:\n    kfac_kwargs.update({\n        'adapt_damping': True,\n        'momentum': None,\n        'momentum_type': 'qmodel',\n    })\n\n  if kfac_kwargs.get('momentum_type', 'regular').lower().startswith('qmodel'):\n    if kfac_kwargs['learning_rate']:\n      raise ValueError('learning_rate must be None to use adaptive/qmodel.')\n    if kfac_kwargs.get('momentum', None):\n      raise ValueError('momentum must be None to use adaptive/qmodel.')\n\n  if kfac_kwargs.get('adapt_damping', False):\n    defaults = {'use_passed_loss': True, 'update_damping_immediately': True}\n    # This way, we keep the user's preferences and only replace missing items.\n    defaults.update(kfac_kwargs)\n    kfac_kwargs.update(defaults)\n\n    if not ('invert_every' in kfac_kwargs and\n            
'damping_adaptation_interval' in kfac_kwargs):\n      # damping_adaptation_interval % invert_every must = 0\n      kfac_kwargs['invert_every'] = 5\n      kfac_kwargs['damping_adaptation_interval'] = 5\n\n\nclass Kfac(tf.keras.optimizers.Optimizer):\n  \"\"\"The KFAC Optimizer for Keras.\"\"\"\n\n  def __init__(self,  # pylint: disable=invalid-name\n               _sentinel=None,\n               learning_rate=None,\n               damping=None,\n               model=None,\n               loss=None,\n               loss_weights=None,\n               fisher_approx=None,\n               layer_collection=None,\n               adaptive=False,\n               train_batch=None,\n               name=None,\n               seed=None,\n               **kfac_kwargs):\n    \"\"\"Construct a new KFAC optimizer.\n\n    If you construct this Optimizer without a model with a loss, model and loss,\n    or a layer_collection, you must call register_layers before using the\n    optimizer.\n\n    If you use adaptive, adapt_damping, or qmodel_momentum, this class will set\n    up the required loss functions and tensors. You must pass the train_batch\n    tensors as a tuple (x, y). If the batch_size cannot be inferred from the\n    train_batch[0] tensor, you pass in the batch_size in the constructor. You\n    may not use numpy arrays as input when using the adaptive mode. If you do\n    not use minimize, you must also provide the loss_tensor.\n\n    When using Distribution Strategy, K-FAC expects a loss tensor that is\n    normalized only by the per-replica batch size, and not the total batch size,\n    unlike what is commonly recommended. This means you cannot use K-FAC with\n    a Distribution Strategy and model.fit at the same time, since model.fit\n    does this scaling for you. Instead, use a custom training loop with\n    Distribution Strategy (there are examples in the Github repo).\n\n    Args:\n      _sentinel: Used to prevent positional parameters. 
Internal, do not use.\n      learning_rate: float or 0D Tensor. Required if not using adapt_damping.\n        Refer to kfac.KfacOptimizer for a detailed description.\n      damping: Required. float or 0D Tensor. Refer to kfac.KfacOptimizer for a\n        detailed description.\n      model: Keras model which this class will optimize. Currently, dense, Conv\n        1D/2D, and embedding are supported as trainable layers.\n      loss: Keras (normal or serialized) loss function. Could be a list or a\n        dictionary mapping layer names to (normal or serialized) loss functions.\n        Currently, sparse/normal categorical/binary cross entropy and MSE are\n        supported.\n      loss_weights: An optional list of coefficients or a dictionary mapping\n        layer names to the coefficient for each loss function. If it is a list,\n        there must be the same number of coefficients as loss functions. If\n        it is a dictionary and a coefficient is not given for a loss function,\n        a coefficient of 1.0 will be used.\n      fisher_approx: An optional list of approximations or a dictionary mapping\n        layer name/class to fisher approximation type. If it is a list, there\n        must be the same number of approximations as there are layers with\n        trainable parameters. For each layer, the approximation is determined as\n        follows. If fisher_approx is a dictionary, first we check if the name is\n        in the dict, if it isn't found the layer class is checked, if it isn't\n        found the default is used. When fisher_approx is a list, the order of\n        the approximations must match the order of the layers with trainable\n        parameters given by model.layers. None is a valid dict/list entry and\n        indicates to use the default approximation for that layer.\n      layer_collection: Only use this argument when you have an unsupported\n        model architecture and so manually register the layers. 
Refer to\n        kfac.KfacOptimizer for a detailed description.\n      adaptive: Whether this optimizer is in adaptive mode or not. In adaptive\n        mode, we set momentum_type='qmodel' and adapt_damping=True, so you must\n        provide the damping (used as the initial value). learning_rate and\n        momentum must be None. You must provide a train_batch and potentially\n        a batch_size if we cannot infer the batch_size from the train_batch.\n      train_batch: A tuple (input, label). The input must be a tensor or a list\n        of tensors that you can call the model on. The label must be a tensor\n        or list of tensors compatible with the loss_fn. See utils.get_loss_fn\n        for the standard loss_fn we create, or you can provide a custom loss_fn.\n      name: Optional name for operations created when applying gradients.\n        Defaults to \"kfac\".\n      seed: Optional integer specifying the TensorFlow random seed. To get\n        deterministic behaviour, the seed needs to be set because the targets\n        are sampled to approximate the fisher.\n      **kfac_kwargs: Additional arguments to be passed to\n        kfac.PeriodicInvCovUpdateKfacOpt (and then to kfac.KfacOptimizer). Note\n        the \"loss\" argument for kfac.KfacOptimizer should be passed as\n        \"loss_tensor\".\n\n    Raises:\n      ValueError: If clipvalue or clipnorm arguments are used.\n      ValueError: If positional arguments are used (or _sentinel is used).\n      ValueError: If damping is not provided.\n      ValueError: If learning_rate or momentum are set with adaptive=True.\n    \"\"\"\n    if tf.executing_eagerly():\n      logging.warn('Eager mode appears to be enabled. 
Kfac is untested in '\n                   'eager mode.')\n    if _sentinel:\n      raise ValueError('Do not pass positional arguments, only use keyword '\n                       'arguments.')\n    if damping is None:\n      raise ValueError('Please provide a value for damping.')\n\n    if 'clipvalue' in kfac_kwargs:\n      raise ValueError('Argument \"clipvalue\" is not support.')\n    if 'clipnorm' in kfac_kwargs:\n      raise ValueError('Argument \"clipnorm\" is not supported. Use '\n                       '\"norm_constraint\" instead.')\n\n    super(Kfac, self).__init__(name=name)\n\n    kfac_kwargs.update({'name': self._name,\n                        'learning_rate': learning_rate,\n                        'damping': damping})\n\n    _configure_kfac_kwargs_for_adaptive(kfac_kwargs, adaptive)\n\n    self._optimizer = None\n    self._layer_collection = None\n    self._model = model\n    self._loss = loss\n    self._have_tracked_vars = False\n    self._tf_var_scope = self._name + '/tf_vars'\n    # We use _kfac_kwargs and _config in various parts in the code below.\n    # _kfac_kwargs is checked when we want to know only what the user passed.\n    # _config is used when we want user selections with the default kwargs as a\n    # fallback.\n    self._kfac_kwargs = kfac_kwargs\n    self._layer_collection_kwargs = {\n        'loss_weights': loss_weights,\n        'fisher_approx': utils.serialize_fisher_approx(fisher_approx),\n        'seed': seed,\n    }\n    self._config = _DEFAULT_KWARGS.copy()\n    self._config.update(kfac_kwargs)\n    self._config.update(self._layer_collection_kwargs)\n    self._config['loss'] = utils.serialize_loss(loss)\n\n    if 'loss_tensor' in self._kfac_kwargs:\n      self._kfac_kwargs['loss'] = self._kfac_kwargs.pop('loss_tensor')\n\n    self._mutable_hypers = _MUTABLE_HYPER_PARAMS.copy()\n    if self._config['adapt_damping']:\n      self._mutable_hypers.remove('damping')\n    if self._config['momentum_type'].lower().startswith('qmodel'):\n 
     self._mutable_hypers -= {'learning_rate', 'momentum'}\n    for hp in self._mutable_hypers.copy():\n      if self._config[hp] is None:\n        self._mutable_hypers.remove(hp)\n      else:\n        self._set_hyper(hp, self._config[hp])\n\n    if layer_collection:\n      self.register_layers(layer_collection=layer_collection)\n    if train_batch and self._kfac_kwargs.get('adapt_damping', False):\n      self.register_train_batch(train_batch=train_batch)\n\n  @property\n  def name(self):\n    # This settable property exists to avoid variable name scope conflicts.\n    return self._name\n\n  @name.setter\n  def name(self, value):\n    if self._optimizer:\n      raise ValueError('Can\\'t change the optimizer\\'s name after the variables'\n                       ' are created')\n    self._name = value\n    self._config['name'] = value\n    self._kfac_kwargs['name'] = value\n    self._tf_var_scope = value + '/tf_vars'\n\n  @property\n  def optimizer(self):\n    # We defer the creation of the optimizer for a few reasons. First, if the\n    # user decides to use the model as a callable, we want to capture the latest\n    # inbound node of the model. Also, this mimics the behaviour of existing\n    # Keras optimizers, as all the variables are created on the first\n    # apply_gradients call (unless the user tries to access this property).\n    # Second, this reduces code duplication as we can use the super class's\n    # _set_hypers and _create_hypers methods. 
Finally, if the user restores an\n    # optimizer, this allows them to control the variable scope before the\n    # variables are created (to avoid scope conflicts).\n    if not self._optimizer:\n      self._create_optimizer()\n    return self._optimizer\n\n  @property\n  def layers(self):\n    return self._layer_collection\n\n  @property\n  def mutable_hyperparameters(self):\n    return self._mutable_hypers\n\n  def register_layers(self, model=None, loss=None, layer_collection=None):\n    if not layer_collection:\n      if not loss and hasattr(model, 'loss'):\n        loss = model.loss\n      if not (model and loss):\n        raise ValueError('Please provide a model with a loss, a model and loss,'\n                         ' or a LayerCollection')\n      layer_collection = utils.get_layer_collection(\n          model, loss, **self._layer_collection_kwargs)\n    self._layer_collection = layer_collection\n    self._kfac_kwargs['var_list'] = layer_collection.registered_variables\n\n  def register_train_batch(self, train_batch, batch_size=None):\n    \"\"\"Configures the train_batch tuple and batch_size for adaptive damping.\"\"\"\n    if not isinstance(train_batch, tuple):\n      raise ValueError('You must provide the train_batch tuple of inputs to '\n                       'use adaptive/adapt_damping mode.')\n    elif not all(isinstance(inp, tf.Tensor) for inp in train_batch):\n      raise ValueError('You must use TF tensors as input.')\n    self._kfac_kwargs['train_batch'] = train_batch\n\n    if batch_size:\n      self._kfac_kwargs['batch_size'] = batch_size\n    elif 'batch_size' not in self._kfac_kwargs:\n      inferred_batch_size = train_batch[0].shape.as_list()[0]\n      if inferred_batch_size:\n        self._kfac_kwargs['batch_size'] = inferred_batch_size\n      else:\n        raise ValueError('Could not infer batch_size from the train_batch. 
'\n                         'Please provide it in the optimizer constructor or '\n                         'through register_train_batch.')\n\n  def minimize(self, loss, var_list, grad_loss=None, name=None):\n    if (self._config['use_passed_loss'] and 'loss' not in self._kfac_kwargs):\n      self._kfac_kwargs['loss'] = loss\n\n    return self._call_and_track_vars(\n        'minimize', loss, var_list=var_list, grad_loss=grad_loss, name=name)\n\n  def apply_gradients(self, grads_and_vars, name=None):\n    return self._call_and_track_vars(\n        'apply_gradients', grads_and_vars, name=name)\n\n  def get_updates(self, loss, params):\n    return [self.minimize(loss, params)]\n\n  def get_config(self):\n    config = self._config.copy()\n    for param in self._hyper:\n      config[param] = self._serialize_hyperparameter(param)\n    return config\n\n  def _create_optimizer(self):\n    \"\"\"Initializes the hyperparameters and sets the self._optimizer property.\"\"\"\n    if self._optimizer:\n      return\n    if not self._layer_collection:\n      self.register_layers(self._model, self._loss)\n\n    if self._config['adapt_damping']:\n      if 'train_batch' not in self._kfac_kwargs:\n        raise ValueError('Must provide a train_batch tuple to use adaptive '\n                         'damping. 
Use register_train_batch or pass it in '\n                         'during optimizer construction.')\n      if 'loss_fn' not in self._kfac_kwargs:\n        self._kfac_kwargs['loss_fn'] = utils.get_loss_fn(\n            self._model, self._loss, loss_weights=self._config['loss_weights'])\n\n    with tf.name_scope(self._name):\n      with tf.init_scope():\n        # \"iterations\" property will create iterations if necessary.\n        _ = self.iterations\n        self._create_hypers()\n\n    self._kfac_kwargs.update(self._hyper)\n    try:\n      # We use the TF 1 variable_scope instead of the TF 2 recommended\n      # name_scope because we need to recover the variables created in this\n      # scope, which is not possible with name_scope.\n      with tf.variable_scope(self._tf_var_scope):\n        self._optimizer = _KFAC_OPT_CLASS(\n            layer_collection=self._layer_collection, **self._kfac_kwargs)\n    except ValueError as e:\n      msg = str(e)\n      if re.search('Variable .* already exists', msg):\n        raise ValueError(\n            'You may have instantiated a KFAC Optimizer with the same name as '\n            'an existing one. Try resetting the default graph, instantiating '\n            'the optimizer with a different name, or changing the optimizer\\'s '\n            'name.\\nHere is the original ValueError:\\n ' + msg)\n      elif re.search('Found the following errors with variable registration'\n                     '.*gamma.*registered with wrong number of uses.*', msg):\n        # We don't regex the name batch_normalization because the user could\n        # have renamed the layer. We don't regex beta because they could have\n        # used BatchNorm without the shift.\n        raise ValueError(\n            'There may have been an issue registering BatchNormalization. Try '\n            'using tf.keras.backend.set_learning_phase before model '\n            'construction. 
An alternative solution is to use the unfused '\n            'batchnorm implementation (pass the argument fused=False to '\n            'BatchNormalization).\\nHere is the original ValueError:\\n ' + msg)\n      else:\n        raise e\n\n  def _call_and_track_vars(self, method_name, *args, **kwargs):\n    # We call _create_optimizer outside of the var_scope because\n    # _create_optimizer also opens the same variable_scope.\n    self._create_optimizer()\n    with tf.variable_scope(self._tf_var_scope):\n      kwargs['global_step'] = self.iterations\n      update_op = getattr(self._optimizer, method_name)(*args, **kwargs)\n\n    if not self._have_tracked_vars:\n      # We rely on the variables created in a deterministic order for get and\n      # set weights. Sorting the variables by name is not a reliable way to\n      # get a deterministic order due to the way TF KFAC assigns variable names.\n      for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,\n                                   scope=self._tf_var_scope):\n        backend.track_variable(var)\n        self.weights.append(var)\n      self._have_tracked_vars = True\n\n    return update_op\n\n  def _set_hyper(self, name, value):\n    \"\"\"Set hyper `name` to value. value must be numeric.\"\"\"\n    if self._hypers_created:\n      if not isinstance(self._hyper[name], tf.Variable):\n        raise AttributeError(\"Can't set attribute: {}\".format(name))\n      if not isinstance(value, numbers.Number):\n        raise ValueError('Dynamic reassignment only supports setting with a '\n                         'number. tf.Tensors and tf.Variables can only be used '\n                         'before the internal kfac optimizer is created.')\n      backend.set_value(self._hyper[name], value)\n    else:\n      super(Kfac, self)._set_hyper(name, value)\n"
  },
  {
    "path": "kfac/python/keras/saving_utils.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Saving/loading utilities for models created with the KFAC Optimizer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport json\nfrom absl import logging\nfrom tensorflow.python.keras.saving import hdf5_format\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.keras import optimizers\n\n# This optional h5py import allows users to import all of tensorflow_kfac\n# without h5py. The ImportError is raised manually if they try to use load_model\n# without h5py. 
This follows the Keras save.py style:\n# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/saving/save.py\ntry:\n  import h5py  # pylint: disable=g-import-not-at-top\nexcept ImportError:\n  h5py = None\n\n\ndef _compile_args_from_training_config(training_config, custom_objects=None):\n  \"\"\"Return model.compile arguments from training config.\"\"\"\n  if custom_objects is None:\n    custom_objects = {}\n\n  optimizer_config = training_config['optimizer_config']\n  optimizer = tf.keras.optimizers.deserialize(\n      optimizer_config, custom_objects=custom_objects)\n\n  # Recover loss functions and metrics.\n  loss_config = training_config['loss']  # Deserialize loss class.\n  if isinstance(loss_config, dict) and 'class_name' in loss_config:\n    loss_config = tf.keras.losses.get(loss_config)\n  loss = tf.nest.map_structure(\n      lambda obj: custom_objects.get(obj, obj), loss_config)\n  metrics = tf.nest.map_structure(\n      lambda obj: custom_objects.get(obj, obj), training_config['metrics'])\n  weighted_metrics = tf.nest.map_structure(\n      lambda obj: custom_objects.get(obj, obj),\n      training_config.get('weighted_metrics', None))\n  sample_weight_mode = training_config['sample_weight_mode']\n  loss_weights = training_config['loss_weights']\n\n  return dict(optimizer=optimizer,\n              loss=loss,\n              metrics=metrics,\n              weighted_metrics=weighted_metrics,\n              loss_weights=loss_weights,\n              sample_weight_mode=sample_weight_mode)\n\n\ndef load_model(filepath, custom_objects=None, optimizer_name=None):\n  \"\"\"Loads and compiles a Keras model saved as an HDF5 file.\n\n  Same as tf.keras.model.load_model, except it will always compile the model\n  and instantiate the Kfac optimizer correctly. 
If you do not want the model to\n  be compiled, or saved without the optimizer, use tf.keras.models.load_model\n  instead.\n\n  Example:\n  ```python:\n  import tensorflow as tf\n  import kfac\n\n  model = tf.keras.Model(...)\n  loss = tf.keras.losses.MSE()  # could be a serialized loss function\n  optimizer = kfac.keras.optimizers.Kfac(0.001, 0.01, model=model, loss=loss)\n  model.compile(optimizer, loss)\n  model.fit(...)\n  model.save('saved_model.hdf5')  # or use tf.keras.models.save_model\n  ...\n  loaded_model = kfac.keras.saving_utils.load_model('saved_model.hdf5')\n  loaded_model.fit(...)\n  ```\n\n  Args:\n    filepath: One of the following:\n        - String, path to the saved model\n        - `h5py.File` object from which to load the model\n    custom_objects: Optional dictionary mapping names (strings) to custom\n      classes or functions to be considered during deserialization. Kfac will\n      be added to this dictionary automatically.\n    optimizer_name: Optional string that specifies what variable scope you want\n      the KFAC variables to be created in. 
Useful if you have multiple KFAC\n      optimizers on one graph.\n\n  Raises:\n    ImportError: If h5py was not imported.\n\n  Returns:\n    A compiled Keras model with the Kfac optimizer correctly initialized.\n  \"\"\"\n  if h5py is None:\n    raise ImportError('`load_model` requires h5py.')\n  if not custom_objects:\n    custom_objects = {}\n  custom_objects['Kfac'] = optimizers.Kfac\n\n  should_open_file = not isinstance(filepath, h5py.File)\n  model_file = h5py.File(filepath, mode='r') if should_open_file else filepath\n\n  model = tf.keras.models.load_model(\n      model_file, custom_objects=custom_objects, compile=False)\n\n  # Code below is current as of 2019-06-20 and may break due to future changes.\n  # github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/saving/hdf5_format.py\n  try:\n    training_config = model_file.attrs.get('training_config')\n    if hasattr(training_config, 'decode'):\n      training_config = training_config.decode('utf-8')\n    if training_config is None:\n      raise ValueError('No training configuration found in save file, meaning '\n                       'the model was not compiled. 
Please use '\n                       'tf.keras.models.load_model instead.')\n    training_config = json.loads(training_config)\n\n    model.compile(**_compile_args_from_training_config(training_config,\n                                                       custom_objects))\n    model.optimizer.register_layers(model)\n    if optimizer_name:\n      model.optimizer.name = optimizer_name\n\n    if 'optimizer_weights' in model_file:\n      # Build train function (to get weight updates).\n      # Models that aren't graph networks must wait until they are called\n      # with data to _make_train_function() and so can't load optimizer\n      # weights.\n      model._make_train_function()  # pylint: disable=protected-access\n      opt_weight_vals = hdf5_format.load_optimizer_weights_from_hdf5_group(\n          model_file)\n      try:\n        model.optimizer.set_weights(opt_weight_vals)\n      except ValueError:\n        logging.warn('Error in loading the saved optimizer state. As a '\n                     'result, your model is starting with a freshly '\n                     'initialized optimizer.')\n  finally:\n    if should_open_file:\n      model_file.close()\n\n  return model\n"
  },
  {
    "path": "kfac/python/keras/utils.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utility Functions for using KFAC with Keras Objects.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport six\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import layer_collection as kfac_layer_collection\nfrom kfac.python.ops.tensormatch import tensorflow_graph_util\n\nlayers = tf.keras.layers\nlosses = tf.keras.losses\nactivations = tf.keras.activations\nK = tf.keras.backend\n\n# Added when serializing layer class names to prevent serialized class names\n# from clashing with user-defined layer names.\n_CLASS_NAME_PREFIX = 'kfac_class_'\n_KERAS_LOSS_TO_KFAC_REGISTER_FUNC = {\n    'sparsecategoricalcrossentropy':\n        kfac_layer_collection.LayerCollection\n        .register_softmax_cross_entropy_loss,\n    'categoricalcrossentropy':\n        kfac_layer_collection.LayerCollection\n        .register_softmax_cross_entropy_loss,\n    'binarycrossentropy':\n        kfac_layer_collection.LayerCollection\n        .register_sigmoid_cross_entropy_loss,\n}\n\n\ndef get_parent(node):\n  \"\"\"Retrieves the parent tf.Tensor of node in the computation graph.\n\n  Args:\n    node: A tf.Tensor.\n\n  Raises:\n   ValueError: If the node has more than one input op.\n   
ValueError: If the node has more than one parent tf.Tensor.\n\n  Returns:\n    The parent tensor of the node on the computation graph.\n  \"\"\"\n  edge = tensorflow_graph_util.expand_inputs(node)\n  if len(edge) != 1:\n    raise ValueError('{} has more than one input op.'.format(node))\n  parent = tensorflow_graph_util.expand_inputs(edge[0])\n  if len(parent) != 1:\n    raise ValueError('{} has more than one parent tensor.'.format(node))\n  return parent[0]\n\n\ndef serialize_loss(loss):\n  \"\"\"Serialize a valid Keras Kfac loss argument.\"\"\"\n  def serialize(x):\n    return x if isinstance(x, six.string_types) else losses.serialize(x)\n\n  if not loss or isinstance(loss, six.string_types):\n    return loss\n  elif isinstance(loss, dict):\n    return {k: serialize(v) for k, v in loss.items()}\n  elif isinstance(loss, list):\n    return [serialize(v) for v in loss]\n  else:\n    return losses.serialize(loss)\n\n\ndef serialize_fisher_approx(fisher_approx):\n  \"\"\"Serialize a valid fisher approximation dict or list.\"\"\"\n  def serialize(key):\n    return (key if isinstance(key, six.string_types) else _CLASS_NAME_PREFIX +\n            key.__name__)\n\n  if isinstance(fisher_approx, dict):\n    fisher_approx = {serialize(k): v for k, v in fisher_approx.items()}\n  return fisher_approx\n\n\ndef _get_verified_dict(container, container_name, layer_names):\n  \"\"\"Verifies that loss_weights/fisher_approx conform to their specs.\"\"\"\n  if container is None or container == {}:  # pylint: disable=g-explicit-bool-comparison\n    # The explicit comparison prevents empty lists from passing.\n    return {}\n  elif isinstance(container, dict):\n    string_keys = {\n        str(k) for k in container if isinstance(k, six.string_types) and\n        not k.startswith(_CLASS_NAME_PREFIX)\n    }\n    if string_keys - set(layer_names):\n      raise ValueError('There is a {} without a matching layer'\n                       .format(container_name))\n    return container\n  elif 
isinstance(container, list):\n    if len(layer_names) != len(container):\n      raise ValueError('Number of {} and layers don\\'t match.'\n                       .format(container_name))\n    return dict(zip(layer_names, container))\n  else:\n    raise ValueError('{} must be a list or dict'.format(container_name))\n\n\ndef register_layer(layer_collection, layer, fisher_approx=None, **kwargs):\n  \"\"\"Get layer collection with all layers and loss registered.\n\n  Args:\n   layer_collection: LayerCollection object on which the layer will be\n     registered.\n   layer: Keras layer to register with the layer_collection.\n   fisher_approx: Option string specifying the fisher approximation type.\n   **kwargs: Keyword arguments to be forwarded to the layer registration\n     function.\n\n  Raises:\n   ValueError: If there is a layer with trainable parameters that isn't Conv1D,\n     Conv2D, Dense, BatchNormalization, LayerNormalization or Embedding.\n   ValueError: If convolutional layers don't use the \"channels_last\" format.\n\n  Returns:\n    A kfac.LayerCollection with the model's layers and loss registered.\n  \"\"\"\n  # The inbound_nodes property is currently deprecated, but appears to be\n  # supported in non-eager TF 1.x. This may change.\n  # If there are multiple inbound_nodes, it means the model was used as a\n  # callable (i.e. y = model(x)). 
We assume the inputs/outputs from the call\n  # need to be registered and not the nodes from the original built model or\n  # any other previous calls, since layers can't be used multiple times\n  # (RNN-style) with Keras KFAC.\n  node = layer.inbound_nodes[-1]\n  pre_activation_output = node.output_tensors\n  if hasattr(layer, 'activation') and layer.activation != activations.linear:\n    pre_activation_output = get_parent(pre_activation_output)\n\n  # This will allow unsupported layers to be in our model as long as KFAC\n  # doesn't have to minimize with respect to those parameters.\n  if layer.count_params() and layer.trainable:\n    if any(isinstance(tensor, (list, tuple))\n           for tensor in (node.input_tensors, node.output_tensors)):\n      raise ValueError('Individual layers can only have 1 input_tensor and 1 '\n                       'output tensor. You are likely using an unsupported '\n                       'layer type. Error on layer {}'.format(layer))\n\n    weights = layer.trainable_weights\n    kwargs.update({\n        'inputs': node.input_tensors,\n        'outputs': pre_activation_output,\n        'params': weights if len(weights) > 1 else weights[0],\n        'approx': fisher_approx,\n    })\n\n    # TODO(b/133849249) Support RNNs and other shared weight layers.\n    if isinstance(layer, layers.Dense):\n      layer_collection.register_fully_connected(**kwargs)\n    elif isinstance(layer, layers.Embedding):\n      layer_collection.register_fully_connected(dense_inputs=False, **kwargs)\n    elif isinstance(layer, (layers.BatchNormalization,\n                            layers.LayerNormalization)):\n      if not layer.scale:\n        # With Batch/Layer Normalization, the user can specify if they want\n        # the input to be scaled and/or shifted after it is normalized.\n        raise ValueError('Kfac currently does not support batch/layer '\n                         'normalization with scale=False. 
Error on layer {}'\n                         .format(layer))\n      # Undo batchnorm by subtracting the shift and diving by scale.\n      kwargs['inputs'] = ((kwargs['outputs'] - weights[1]) / weights[0]\n                          if layer.center else kwargs['outputs'] / weights)\n      layer_collection.register_scale_and_shift(**kwargs)\n\n      # A learning_phase of 1 or 0 means it's been set. False means it hasn't.\n      is_phase_set = K.get_value(K.learning_phase()) != False  # pylint: disable=g-explicit-bool-comparison\n      if hasattr(layer, 'fused') and layer.fused and not is_phase_set:\n        # For the fused implementation of the BatchNormalization, there are\n        # two ops: one for training and one for inference. When the\n        # learning_phase is set, during layer creation, there is a\n        # tf_utils.smart_cond that will only create one of the ops. When the\n        # learning_phase is not set, it will create a tf.cond with both ops as\n        # branches. So, when learning_phase is not set, we must add a \"use\"\n        # for the gamma/beta variables to account for there being two ops that\n        # are consumers of the variables. Linked below is the smart_cond in\n        # BatchNormalization:\n        # https://github.com/tensorflow/tensorflow/blob/59217f581fdef4e5469a98b62e38f851eac88688/tensorflow/python/keras/layers/normalization.py#L513\n        # Updated 2019-06-22.\n        layer_collection._add_uses(weights, 1)  # pylint: disable=protected-access\n\n    elif all(hasattr(layer, a) for a in\n             ('strides', 'padding', 'dilation_rate')):\n      if layer.data_format != 'channels_last':\n        raise ValueError('KFAC currently only supports the \"channels_last\" '\n                         'data format for convolutional layers. 
Error on '\n                         'layer {}'.format(layer))\n\n      kwargs['padding'] = layer.padding.upper()\n      kwargs['strides'] = [1] + list(layer.strides) + [1]\n      kwargs['dilations'] = [1] + list(layer.dilation_rate) + [1]\n\n      if isinstance(layer, layers.Conv2D):\n        layer_collection.register_conv2d(**kwargs)\n      elif isinstance(layer, layers.Conv1D):\n        layer_collection.register_conv1d(**kwargs)\n      # Depthwise and Separable Conv2D are not supported yet because they are\n      # experimental in tensorflow_kfac.\n      else:\n        raise ValueError('Unsupported convolutional layer type: {}'\n                         .format(layer))\n        # TODO(b/133849240): Support registering any convolution type.\n    else:\n      raise ValueError('Unsupported layer type: {}'.format(layer))\n      # TODO(b/133849243): Support registering any generic layer type.\n\n\ndef register_loss(layer_collection, layer, loss, **kwargs):\n  \"\"\"Registers the loss with the layer for the layer_collection.\n\n  Args:\n   layer_collection: LayerCollection object on which the layer and loss will\n     be registered.\n   layer: Keras layer whose outputs will be used with the loss function.\n   loss: Keras (normal or serialized) loss function. Currently,\n     sparse/normal categorical/binary cross entropy and MSE are supported.\n   **kwargs: Keyword arguments to be forwarded to the function that\n     registers the loss. 
A couple of notable ones include coeff (the weight of\n     the loss) and seed (the seed used when sampling from the output\n     distribution).\n\n  Raises:\n   ValueError: If a loss function other than MSE and cross entropy\n     variants is used.\n\n  Raises:\n   ValueError: If a loss function other than MSE and cross entropy\n     variants is used.\n  \"\"\"\n  node = layer.inbound_nodes[-1]\n  pre_activation_output = node.output_tensors\n  if hasattr(layer, 'activation') and layer.activation != activations.linear:\n    pre_activation_output = get_parent(pre_activation_output)\n\n  # A Keras loss can be a callable class or a function. Their serialized\n  # forms differ. The logic below normalizes these difference. This will\n  # not work for custom losses (we do not intend to support custom loss\n  # functions for now).\n  if not isinstance(loss, six.string_types):\n    loss = losses.serialize(loss)\n  if isinstance(loss, dict):\n    loss = loss['class_name']\n  loss = loss.replace('_', '').lower()\n\n  if loss in ('meansquarederror', 'mse'):\n    # We use the actual output here instead of the pre-activations because\n    # MSE is computed with the output. For the logit loss functions,\n    # tensorflow_kfac needs the pre-activations.\n    layer_collection.register_squared_error_loss(layer.output, **kwargs)\n  elif loss in _KERAS_LOSS_TO_KFAC_REGISTER_FUNC:\n    _KERAS_LOSS_TO_KFAC_REGISTER_FUNC[loss](\n        layer_collection, logits=pre_activation_output, **kwargs)\n  else:\n    raise ValueError('Unsupported loss function: {}'.format(loss))\n\n\ndef get_layer_collection(model,\n                         loss=None,\n                         loss_weights=None,\n                         fisher_approx=None,\n                         layer_collection=None,\n                         seed=None):\n  \"\"\"Get layer collection with all layers and loss registered.\n\n  Args:\n   model: Keras model whose layers to register. 
Currently, Conv1D,\n     Conv2D, Dense, BatchNormalization, LayerNormalization and Embedding layers\n     are supported in a Functional or Sequential model. Other layer types are\n     supported as long as they aren't trainable (or don't have weights). Nested\n     models are supported.\n   loss: Optional Keras (normal or serialized) loss function. Could be a list or\n     a dictionary mapping layer names to (normal or serialized) loss functions.\n     if there are multiple losses Currently, sparse/normal categorical/binary\n     cross entropy and MSE are supported. You must register at least one loss\n     with the layer collection before it can be used.\n   loss_weights: An optional list of coefficients or a dictionary mapping\n     layer names to the coefficient for each loss function. If it is a list,\n     there must be a the same number of coefficients as loss functions. If\n     it is a dictionary and a coefficient is not given for a loss function,\n     a coefficient of 1.0 will be used.\n   fisher_approx: An optional list of approximations or a dictionary mapping\n     layer name/class to fisher approximation type. If it is a list, there must\n     be the same number of approximations as there are layers with trainable\n     parameters. For each layer, the approximation is determined as follows:\n     if fisher_approx is a dictionary, first we check if the name is in the\n     dict, if it isn't found the layer class is checked, if that isn't found\n     the default is used. When fisher_approx is a list, the order of the\n     approximations must match the order of the layers with trainable parameters\n     given by model.layers. None is a valid dict/list entry and indicates to use\n     the default approximation for that layer.\n   layer_collection: Optional LayerCollection object on which the model and loss\n     will be registered.\n   seed: Optional integer specifing the TensorFlow random seed. 
To get\n     deterministic behaviour, the seed needs to be set because the targets\n     are sampled to approximate the fisher.\n\n  Raises:\n   ValueError: If there is a layer with trainable parameters that isn't Conv1D,\n     Conv2D, Dense, BatchNormalization, LayerNormalization or Embedding.\n   ValueError: If a loss function other than MSE and cross entropy\n     variants is used.\n   ValueError: If there isn't a one-to-one correspondence between\n     loss/loss_weights and output layers, or if loss_weights isn't a list/dict.\n   ValueError: If convolutional layers don't use the \"channels_last\" format.\n\n  Returns:\n    A kfac.LayerCollection with the model's layers and loss registered.\n  \"\"\"\n  if not layer_collection:\n    layer_collection = kfac_layer_collection.LayerCollection()\n\n  if not loss:\n    loss = {}\n  elif isinstance(loss, dict):\n    if set(model.output_names) != set(loss.keys()):\n      raise ValueError('Output layer names and loss dict keys don\\'t match'\n                       ' \\nmodel.output_names: {} \\nloss dict keys: {}'\n                       .format(model.output_names, loss.keys()))\n  elif isinstance(loss, list):\n    if len(model.output_names) != len(loss):\n      raise ValueError('Number of loss dict items doesn\\'t match number of '\n                       'output layers. \\nmodel.output_names: {} \\nloss list: '\n                       '{}'.format(model.output_names, loss))\n    loss = dict(zip(model.output_names, loss))\n  else:\n    if len(model.output_names) > 1:\n      raise ValueError('More output layers than losses. \\n'\n                       'model.output_names: {} \\nloss: {}'\n                       .format(model.output_names, loss))\n    # When the model is used as a callable, the model's output_names may not\n    # match the actual output layer's name. 
In the one output case, we always\n    # want the last layer, so we use the last layer's name.\n    loss = {model.layers[-1].name: loss}\n\n  # We want to do a left-to-right depth-first traversal of the model to get the\n  # correct flattened order of the layers. The order only matters for the\n  # fisher_approx in list form.\n  flattened_layers = []\n  layer_stack = model.layers[::-1]\n  while layer_stack:\n    layer = layer_stack.pop()\n    if hasattr(layer, 'layers'):\n      if layer.name in loss:\n        if len(layer.output_names) > 1:\n          raise ValueError('Nested models with multiple outputs are '\n                           'unsupported.')\n        loss[layer.output_names[0]] = loss.pop(layer.name)\n      layer_stack += layer.layers[::-1]\n    else:\n      flattened_layers.append(layer)\n\n  trainable_layer_names = [l.name for l in flattened_layers if\n                           l.count_params() and l.trainable]\n  fisher_approx = _get_verified_dict(fisher_approx, 'fisher_approx',\n                                     trainable_layer_names)\n  # The Optimizer class passes in a serialized fisher_approx dictionary, but the\n  # user may not. 
We serialize it so we can use it uniformly.\n  fisher_approx = serialize_fisher_approx(fisher_approx)\n  loss_weights = _get_verified_dict(loss_weights, 'loss_weights',\n                                    model.output_names)\n\n  for layer in flattened_layers:\n    if layer.name in fisher_approx:\n      approx = fisher_approx[layer.name]\n    else:\n      approx = fisher_approx.get(\n          _CLASS_NAME_PREFIX + layer.__class__.__name__, None)\n\n    register_layer(layer_collection, layer, fisher_approx=approx)\n\n    if layer.name in loss:\n      register_loss(layer_collection=layer_collection,\n                    layer=layer,\n                    loss=loss[layer.name],\n                    coeff=loss_weights.get(layer.name, 1.0),\n                    seed=seed)\n\n  return layer_collection\n\n\ndef get_loss_fn(model,\n                loss,\n                training=None,\n                loss_weights=None,\n                reduce_fn=tf.reduce_mean,\n                name='loss'):\n  \"\"\"Creates a loss function to be used for KFAC's adaptive damping.\n\n  This allows Keras KFAC to automatically create the loss function to use\n  for adaptive_damping. This function would also be useful for a custom training\n  loop that uses adaptive_damping.\n\n  The returned loss function currently does not support masks or sample_weights.\n\n  Currently, if you use a categorical crossentropy loss, due to the\n  implementation of tf.keras.losses.*_crossentropy, it will  grab the logits\n  whether you use a softmax at the end of your model or not. This is true as of\n  August 1, 2019. 
Code below:\n  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/backend.py#L4322\n\n  Args:\n    model: tf.keras.Model model that will be used with the inputs to the\n      returned loss_fn.\n    loss: Potentially serialized tf.keras.losses.* loss function(s)/class(es).\n      If the model has multiple outputs, this must be a list of losses that\n      matches the order of model.outputs, or a dictionary with names matching\n      output_names. Must accept kwargs y_pred and y_true. Note that if your\n      model's outputs are logits, you should pass a callable Keras with\n      from_logits=True. This function could be a non-Keras loss, but it is\n      untested in this case.\n    training: Boolean indicating whether or not the loss is used in training or\n      test time. This is necessary to set the proper mode for batch norm and\n      dropout layers. If None then falls back to Keras behavior of calling the\n      model without passing a value for training.\n    loss_weights: If you have multiple losses, a list or dictionary of weights\n      for each loss. A default value of 1.0 is given for losses that don't have\n      a weight when a dictionary is passed.\n    reduce_fn: The function that will be used to aggregate the loss tensor.\n      tf.reduce_mean by default. You may replace this with the identity if your\n      loss does a reduction by default. Depending on how you compute your loss\n      in a distributed setting, you may want to modify this function (for\n      example, if you sum across replicas, then the reduce_fn might be\n      lambda x: tf.reduce_sum(x) * (1.0 / global_batch_size)).\n    name: Name scope for the loss_fn ops.\n\n  Raises:\n    ValueError: If the loss is a dictionary.\n\n  Returns:\n    A function that takes inputs and optionally a prediction and will return\n    a loss. 
This can be used as the KFAC loss_fn for adaptive damping.\n  \"\"\"\n  if isinstance(loss, six.string_types):\n    loss = losses.deserialize(loss)\n  elif isinstance(loss, dict):\n    loss = [loss[n] for n in model.output_names]\n\n  if isinstance(loss, list):\n    loss = [losses.deserialize(l) if isinstance(l, six.string_types) else l\n            for l in loss]\n\n  if isinstance(loss_weights, dict):\n    loss_weights = [loss_weights.get(n, 1.0) for n in model.output_names]\n\n  def loss_fn(inputs, prediction=None):\n    \"\"\"Computes loss for a model given inputs.\n\n    This function is meant to be used with K-FAC's adaptive damping, which is\n    why the prediction is optional (since K-FAC wants to compute the loss just\n    given inputs).\n\n    Args:\n      inputs: A tuple with (model_input(s), label(s)), where both elements are\n        tensors or lists/tuples of tensors.\n      prediction: The output of the model given the inputs. If this isn't\n        provided, the prediction will be computed via\n        prediction = model(inputs[0])\n\n    Returns:\n      A tensor with the total reduced loss including regularization and other\n      layer specific losses.\n    \"\"\"\n    with tf.name_scope(name):\n      x, y = inputs\n      if prediction is None:\n        if training is not None:\n          prediction = model(x, training=training)\n        else:\n          prediction = model(x)\n\n      if isinstance(prediction, (tuple, list)):\n        reduced_losses = [reduce_fn(fn(y_pred=pred_i, y_true=y_i))\n                          for fn, pred_i, y_i in zip(loss, prediction, y)]\n        if loss_weights:\n          reduced_losses = [l * w for l, w in zip(reduced_losses, loss_weights)]\n        total_loss = tf.add_n(reduced_losses)\n      else:\n        total_loss = reduce_fn(loss(y_pred=prediction, y_true=y))\n\n      # Adds regularization penalties and other custom layer specific losses.\n      if model.losses:\n        total_loss += 
tf.add_n(model.losses)\n\n    return total_loss\n\n  return loss_fn\n\n"
  },
  {
    "path": "kfac/python/kernel_tests/data_reader_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for CachedDataReader class.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n\n# Dependency imports\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops.kfac_utils import data_reader\n\n\nclass DataReaderTest(tf.test.TestCase):\n\n  def test_read_batch(self):\n    max_batch_size = 10\n    batch_size_schedule = [2, 4, 6, 8]\n    data_set = tf.random_uniform(shape=(max_batch_size, 784), maxval=1.)\n    var_data = data_reader.CachedDataReader(\n        (data_set,), max_batch_size)\n    cur_batch_size = tf.placeholder(\n        shape=(), dtype=tf.int32, name='cur_batch_size')\n    # Force create the ops\n    data = var_data(cur_batch_size)[0]\n    with self.test_session() as sess:\n      sess.run(tf.global_variables_initializer())\n      coord = tf.train.Coordinator()\n      tf.train.start_queue_runners(sess=sess, coord=coord)\n      for batch_size in batch_size_schedule:\n        data_ = sess.run(\n            data, feed_dict={cur_batch_size: batch_size})\n        self.assertEqual(len(data_), batch_size)\n        self.assertEqual(len(data_[0]), 784)\n\n  def test_cached_batch(self):\n    max_batch_size = 100\n    data_set = tf.random_uniform(shape=(max_batch_size, 
784), maxval=1.)\n    var_data = data_reader.CachedDataReader(\n        (data_set,), max_batch_size)\n    cur_batch_size = tf.placeholder(\n        shape=(), dtype=tf.int32, name='cur_batch_size')\n    # Force create the ops\n    data = var_data(cur_batch_size)[0]\n    with self.test_session() as sess:\n      sess.run(tf.global_variables_initializer())\n      coord = tf.train.Coordinator()\n      tf.train.start_queue_runners(sess=sess, coord=coord)\n      data_ = sess.run(data, feed_dict={cur_batch_size: 25})\n      stored_data_ = sess.run(var_data.cached_batch)[0]\n      self.assertListEqual(list(data_[1]), list(stored_data_[1]))\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/estimator_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for kfac.estimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import estimator\nfrom kfac.python.ops import fisher_factors as ff\nfrom kfac.python.ops import layer_collection as lc\nfrom kfac.python.ops import utils\n\n\n# We need to set these constants since the numerical values used in the tests\n# were chosen when these used to be the defaults.\nff.set_global_constants(zero_debias=False)\n\n\n_ALL_ESTIMATION_MODES = [\"gradients\", \"empirical\", \"curvature_prop\", \"exact\"]\n\n\nclass EstimatorTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._graph = tf.Graph()\n    with self._graph.as_default():\n      self.layer_collection = lc.LayerCollection()\n\n      self.inputs = tf.random_normal((2, 2), dtype=tf.float32)\n      self.weights = tf.get_variable(\"w\", shape=(2, 2), dtype=tf.float32)\n      self.bias = tf.get_variable(\n          \"b\", initializer=tf.zeros_initializer(), shape=(2, 1))\n      self.output = tf.matmul(self.inputs, self.weights) + self.bias\n\n      # Only register the weights.\n      self.layer_collection.register_fully_connected(\n    
      params=(self.weights,), inputs=self.inputs, outputs=self.output)\n\n      self.outputs = tf.tanh(self.output)\n      self.targets = tf.zeros_like(self.outputs)\n      self.layer_collection.register_categorical_predictive_distribution(\n          logits=self.outputs, targets=self.targets)\n\n  def testEstimatorInitManualRegistration(self):\n    with self._graph.as_default():\n      # We should be able to build an estimator for only the registered vars.\n      estimator.FisherEstimatorRoundRobin(\n          variables=[self.weights],\n          cov_ema_decay=0.1,\n          damping=0.2,\n          layer_collection=self.layer_collection\n      )\n\n      # Check that we throw an error if we try to build an estimator for vars\n      # that were not manually registered.\n      with self.assertRaises(ValueError):\n        est = estimator.FisherEstimatorRoundRobin(\n            variables=[self.weights, self.bias],\n            cov_ema_decay=0.1,\n            damping=0.2,\n            layer_collection=self.layer_collection\n        )\n        est.make_vars_and_create_op_thunks()\n\n      # Check that we throw an error if we don't include registered variables,\n      # i.e. 
self.weights\n      with self.assertRaises(ValueError):\n        est = estimator.FisherEstimatorRoundRobin(\n            variables=[],\n            cov_ema_decay=0.1,\n            damping=0.2,\n            layer_collection=self.layer_collection)\n        est.make_vars_and_create_op_thunks()\n\n  @tf.test.mock.patch.object(utils.SubGraph, \"variable_uses\", return_value=42)\n  def testVariableWrongNumberOfUses(self, mock_uses):\n    with self.assertRaises(ValueError):\n      est = estimator.FisherEstimatorRoundRobin(\n          variables=[self.weights],\n          cov_ema_decay=0.1,\n          damping=0.2,\n          layer_collection=self.layer_collection)\n      est.make_vars_and_create_op_thunks()\n\n  def testInvalidEstimationMode(self):\n    with self.assertRaises(ValueError):\n      est = estimator.FisherEstimatorRoundRobin(\n          variables=[self.weights],\n          cov_ema_decay=0.1,\n          damping=0.2,\n          layer_collection=self.layer_collection,\n          estimation_mode=\"not_a_real_mode\")\n      est.make_vars_and_create_op_thunks()\n\n  def testGradientsModeBuild(self):\n    with self._graph.as_default():\n      est = estimator.FisherEstimatorRoundRobin(\n          variables=[self.weights],\n          cov_ema_decay=0.1,\n          damping=0.2,\n          layer_collection=self.layer_collection,\n          estimation_mode=\"gradients\")\n      est.make_vars_and_create_op_thunks()\n\n  def testEmpiricalModeBuild(self):\n    with self._graph.as_default():\n      est = estimator.FisherEstimatorRoundRobin(\n          variables=[self.weights],\n          cov_ema_decay=0.1,\n          damping=0.2,\n          layer_collection=self.layer_collection,\n          estimation_mode=\"empirical\")\n      est.make_vars_and_create_op_thunks()\n\n  def testCurvaturePropModeBuild(self):\n    with self._graph.as_default():\n      est = estimator.FisherEstimatorRoundRobin(\n          variables=[self.weights],\n          cov_ema_decay=0.1,\n          
damping=0.2,\n          layer_collection=self.layer_collection,\n          estimation_mode=\"curvature_prop\")\n      est.make_vars_and_create_op_thunks()\n\n  def testExactModeBuild(self):\n    with self._graph.as_default():\n      est = estimator.FisherEstimatorRoundRobin(\n          variables=[self.weights],\n          cov_ema_decay=0.1,\n          damping=0.2,\n          layer_collection=self.layer_collection,\n          estimation_mode=\"exact\")\n      est.make_vars_and_create_op_thunks()\n\n  def test_cov_update_thunks(self):\n    \"\"\"Ensures covariance update ops run once per global_step.\"\"\"\n    with self._graph.as_default(), self.test_session() as sess:\n      fisher_estimator = estimator.FisherEstimatorRoundRobin(\n          variables=[self.weights],\n          layer_collection=self.layer_collection,\n          damping=0.2,\n          cov_ema_decay=0.0)\n\n      # Construct an op that executes one covariance update per step.\n      global_step = tf.train.get_or_create_global_step()\n      (cov_variable_thunks, cov_update_op_thunks, _,\n       _) = fisher_estimator.create_ops_and_vars_thunks()\n      for thunk in cov_variable_thunks:\n        thunk()\n      cov_matrices = [\n          fisher_factor.cov\n          for fisher_factor in self.layer_collection.get_factors()\n      ]\n      cov_update_op = tf.case([(tf.equal(global_step, i), thunk)\n                               for i, thunk in enumerate(cov_update_op_thunks)])\n      increment_global_step = global_step.assign_add(1)\n\n      sess.run(tf.global_variables_initializer())\n      initial_cov_values = sess.run(cov_matrices)\n\n      # Ensure there's one update per covariance matrix.\n      self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))\n\n      # Test is no-op if only 1 covariance matrix.\n      assert len(cov_matrices) > 1\n\n      for i in range(len(cov_matrices)):\n        # Compare new and old covariance values\n        new_cov_values = sess.run(cov_matrices)\n        
is_cov_equal = [\n            np.allclose(initial_cov_value, new_cov_value)\n            for (initial_cov_value,\n                 new_cov_value) in zip(initial_cov_values, new_cov_values)\n        ]\n        num_cov_equal = sum(is_cov_equal)\n\n        # Ensure exactly one covariance matrix changes per step.\n        self.assertEqual(num_cov_equal, len(cov_matrices) - i)\n\n        # Run all covariance update ops.\n        sess.run(cov_update_op)\n        sess.run(increment_global_step)\n\n  def test_round_robin_placement(self):\n    \"\"\"Check if the ops and variables are placed on devices correctly.\"\"\"\n    with self._graph.as_default():\n      fisher_estimator = estimator.FisherEstimatorRoundRobin(\n          variables=[self.weights],\n          layer_collection=self.layer_collection,\n          damping=0.2,\n          cov_ema_decay=0.0,\n          cov_devices=[\"/cpu:{}\".format(i) for i in range(2)],\n          inv_devices=[\"/cpu:{}\".format(i) for i in range(2)])\n\n      # Construct an op that executes one covariance update per step.\n      (cov_update_thunks,\n       inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(\n           scope=\"test\")\n      cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)\n      inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)\n      self.assertEqual(cov_update_ops[0].device, \"/device:CPU:0\")\n      self.assertEqual(cov_update_ops[1].device, \"/device:CPU:1\")\n      self.assertEqual(inv_update_ops[0].device, \"/device:CPU:0\")\n      self.assertEqual(inv_update_ops[1].device, \"/device:CPU:1\")\n      cov_matrices = [\n          fisher_factor._cov._var\n          for fisher_factor in self.layer_collection.get_factors()\n      ]\n      inv_matrices = [\n          matrix\n          for fisher_factor in self.layer_collection.get_factors()\n          for matrix in fisher_factor._matpower_by_exp_and_damping.values()\n      ]\n      self.assertEqual(cov_matrices[0].device, 
\"/device:CPU:0\")\n      self.assertEqual(cov_matrices[1].device, \"/device:CPU:1\")\n      # Inverse matrices need to be explicitly placed.\n      self.assertEqual(inv_matrices[0].device, \"\")\n      self.assertEqual(inv_matrices[1].device, \"\")\n\n  def test_inv_update_thunks(self):\n    \"\"\"Ensures inverse update ops run once per global_step.\"\"\"\n    with self._graph.as_default(), self.test_session() as sess:\n      fisher_estimator = estimator.FisherEstimatorRoundRobin(\n          variables=[self.weights],\n          layer_collection=self.layer_collection,\n          damping=0.2,\n          cov_ema_decay=0.0)\n\n      # Construct op that updates one inverse per global step.\n      global_step = tf.train.get_or_create_global_step()\n      (cov_variable_thunks, _, inv_variable_thunks,\n       inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()\n      for thunk in cov_variable_thunks:\n        thunk()\n      for thunk in inv_variable_thunks:\n        thunk()\n      inv_matrices = [\n          matrix\n          for fisher_factor in self.layer_collection.get_factors()\n          for matrix in fisher_factor._matpower_by_exp_and_damping.values()\n      ]\n      inv_update_op = tf.case([(tf.equal(global_step, i), thunk)\n                               for i, thunk in enumerate(inv_update_op_thunks)])\n      increment_global_step = global_step.assign_add(1)\n\n      sess.run(tf.global_variables_initializer())\n      initial_inv_values = sess.run(inv_matrices)\n\n      # Ensure there's one update per inverse matrix. This is true as long as\n      # there's no fan-in/fan-out or parameter re-use.\n      self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))\n\n      # Test is no-op if only 1 invariance matrix.\n      assert len(inv_matrices) > 1\n\n      # Assign each covariance matrix a value other than the identity. 
This\n      # ensures that the inverse matrices are updated to something different as\n      # well.\n      sess.run([\n          fisher_factor._cov.add_to_average(\n              2 * tf.eye(int(fisher_factor._cov_shape[0])))\n          for fisher_factor in self.layer_collection.get_factors()\n      ])\n\n      for i in range(len(inv_matrices)):\n        # Compare new and old inverse values\n        new_inv_values = sess.run(inv_matrices)\n        is_inv_equal = [\n            np.allclose(initial_inv_value, new_inv_value)\n            for (initial_inv_value,\n                 new_inv_value) in zip(initial_inv_values, new_inv_values)\n        ]\n        num_inv_equal = sum(is_inv_equal)\n\n        # Ensure exactly one inverse matrix changes per step.\n        self.assertEqual(num_inv_equal, len(inv_matrices) - i)\n\n        # Run all inverse update ops.\n        sess.run(inv_update_op)\n        sess.run(increment_global_step)\n\n\nif __name__ == \"__main__\":\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/graph_search_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for tensormatch/graph_search.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\n\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import fisher_blocks as fb\nfrom kfac.python.ops import layer_collection as lc\nfrom kfac.python.ops import optimizer\n\nfrom kfac.python.ops.tensormatch import graph_search as gs\n\n\ndef _build_model():\n  w = tf.get_variable('W', [10, 10])\n  b_1 = tf.get_variable('b_1', [\n      10,\n  ])\n  b_0 = tf.get_variable('b_0', [\n      10,\n  ])\n  x = tf.placeholder(tf.float32, shape=(32, 10))\n  y = tf.placeholder(tf.float32, shape=(32, 10))\n\n  pre_bias_0 = tf.matmul(x, w)\n  pre_bias_1 = tf.matmul(y, w)\n\n  out_0 = pre_bias_0 + b_0  # pylint: disable=unused-variable\n  out_1 = pre_bias_1 + b_1  # pylint: disable=unused-variable\n\n  tensor_dict = {}\n\n  tensor_dict['w'] = w\n  tensor_dict['b_0'] = b_0\n  tensor_dict['b_1'] = b_1\n  tensor_dict['x'] = x\n  tensor_dict['y'] = y\n  tensor_dict['pre_bias_0'] = pre_bias_0\n  tensor_dict['pre_bias_1'] = pre_bias_1\n  tensor_dict['out_0'] = out_0\n  tensor_dict['out_1'] = out_1\n\n  return tensor_dict\n\n\ndef _build_mock_records():\n  tensor_dict = _build_model()\n 
 weight_record = gs.MatchRecord(\n      record_type=gs.RecordType.fully_connected,\n      params=tensor_dict['w'],\n      tensor_set={\n          tensor_dict['x'], tensor_dict['w'], tensor_dict['pre_bias_0']\n      })\n  weight_and_bias_0_record = gs.MatchRecord(\n      record_type=gs.RecordType.fully_connected,\n      params=(tensor_dict['w'], tensor_dict['b_0']),\n      tensor_set={\n          tensor_dict['x'], tensor_dict['w'], tensor_dict['pre_bias_0'],\n          tensor_dict['b_0'], tensor_dict['out_0']\n      })\n  bias_0_record = gs.MatchRecord(\n      record_type=gs.RecordType.fully_connected,\n      params=tensor_dict['b_0'],\n      tensor_set={\n          tensor_dict['pre_bias_0'], tensor_dict['b_0'], tensor_dict['out_0']\n      })\n  weight_and_bias_1_record = gs.MatchRecord(\n      record_type=gs.RecordType.fully_connected,\n      params=(tensor_dict['w'], tensor_dict['b_1']),\n      tensor_set={\n          tensor_dict['y'], tensor_dict['w'], tensor_dict['pre_bias_1'],\n          tensor_dict['b_1'], tensor_dict['out_1']\n      })\n  record_list_dict = collections.defaultdict(list)\n  for record in [\n      weight_record, weight_and_bias_0_record, bias_0_record,\n      weight_and_bias_1_record\n  ]:\n    record_list_dict[record.params].append(record)\n  return tensor_dict, dict(record_list_dict)\n\n\ndef assert_fisher_blocks_match(test_case, layer_collection_a,\n                               layer_collection_b):\n  \"\"\"Check that two `LayerCollection`s have matching fisher_blocks.\"\"\"\n\n  fisher_blocks_a = layer_collection_a.fisher_blocks\n  fisher_blocks_b = layer_collection_b.fisher_blocks\n\n  test_case.assertSetEqual(\n      set(fisher_blocks_a.keys()), set(fisher_blocks_b.keys()))\n\n  for parameters, block_a in fisher_blocks_a.items():\n    block_b = fisher_blocks_b[parameters]\n    test_case.assertEqual(type(block_a), type(block_b))\n    if hasattr(block_a, '_inputs'):\n      test_case.assertEqual(block_a._inputs, block_b._inputs)  # pylint: 
disable=protected-access\n      test_case.assertEqual(block_a._outputs, block_b._outputs)  # pylint: disable=protected-access\n    else:\n      test_case.assertEqual(block_a._params, block_b._params)  # pylint: disable=protected-access\n\n\ndef sparse_softmax_cross_entropy(labels,\n                                 logits,\n                                 num_classes,\n                                 weights=1.0,\n                                 label_smoothing=0.1):\n  \"\"\"Softmax cross entropy with example weights, label smoothing.\"\"\"\n  assert_valid_label = [\n      tf.assert_greater_equal(labels, tf.cast(0, dtype=tf.int64)),\n      tf.assert_less(labels, tf.cast(num_classes, dtype=tf.int64))\n  ]\n  with tf.control_dependencies(assert_valid_label):\n    labels = tf.reshape(labels, [-1])\n    dense_labels = tf.one_hot(labels, num_classes)\n    loss = tf.losses.softmax_cross_entropy(\n        onehot_labels=dense_labels,\n        logits=logits,\n        weights=weights,\n        label_smoothing=label_smoothing)\n  return loss\n\n\nclass GraphSearchTestCase(tf.test.TestCase):\n\n  def testRegisterLayers(self):\n    \"\"\"Ensure graph search can find a single layer network.\"\"\"\n    with tf.Graph().as_default():\n      layer_collection = lc.LayerCollection()\n\n      # Construct a 1-layer model.\n      inputs = tf.ones((2, 1)) * 2\n      weights = tf.get_variable(\n          'w',\n          shape=(1, 1),\n          dtype=tf.float32,\n          initializer=tf.random_normal_initializer)\n      bias = tf.get_variable(\n          'b', initializer=tf.zeros_initializer(), shape=(1, 1))\n      non_variable_bias = tf.ones((1, 1))\n      output = tf.matmul(inputs, weights) + bias + non_variable_bias\n      logits = tf.tanh(output)\n\n      # Register posterior distribution. 
Graph search will infer variables\n      # needed to construct this.\n      layer_collection.register_categorical_predictive_distribution(logits)\n\n      # Register variables.\n      gs.register_layers(layer_collection, tf.trainable_variables())\n\n      # Ensure 1-layer got registered.\n      self.assertEqual(\n          [(weights, bias)],\n          list(layer_collection.fisher_blocks.keys()))\n      self.assertEqual(1, len(layer_collection.losses))\n\n  def test_register_records_order(self):\n    \"\"\"Ensure records are always registered in the same order.\"\"\"\n    with tf.Graph().as_default():\n      data = {'inputs': tf.zeros([10, 4]), 'outputs': tf.zeros([10, 3]),\n              'dense_inputs': True}\n      params1 = tf.get_variable('w1', [4, 3])\n      record1 = gs.MatchRecord(\n          gs.RecordType.fully_connected, params1, set(), data=data)\n\n      params2 = (tf.get_variable('w2', [4, 3]),\n                 tf.get_variable('b2', [3]))\n      record2 = gs.MatchRecord(\n          gs.RecordType.fully_connected, params2, set(), data=data)\n\n      # Create a dict of records.\n      records = collections.OrderedDict()\n      records[params1] = [record1]\n      records[params2] = [record2]\n\n      # Register variables.\n      layer_collection = lc.LayerCollection(name='lc1')\n      gs.register_records(layer_collection, records)\n\n      # Ensure order matches lexicographic order.\n      self.assertEqual([params2, params1],\n                       list(layer_collection.fisher_blocks.keys()))\n\n      # Create a dict of records in a different order.\n      records = collections.OrderedDict()\n      records[params2] = [record2]\n      records[params1] = [record1]\n\n      # Register variables.\n      layer_collection = lc.LayerCollection(name='lc2')\n      gs.register_records(layer_collection, records)\n\n      # Ensure order matches lexicographic order.\n      self.assertEqual([params2, params1],\n                       
list(layer_collection.fisher_blocks.keys()))\n\n  def test_multitower_examples_model(self):\n    \"\"\"Ensure graph search runs properly on a multitower setup.\n\n    This test uses linear_model from examples/convnets.\n    \"\"\"\n    with tf.Graph().as_default():\n      def linear_model(images, labels, num_classes):\n        \"\"\"Creates a linear model.\n\n        Args:\n          images: The input image tensors, a tensor of size\n              (batch_size x height_in x width_in x channels).\n          labels: The sparse target labels, a tensor of size (batch_size x 1).\n          num_classes: The number of classes, needed for one-hot encoding (int).\n\n        Returns:\n          loss: The total loss for this model (0-D tensor).\n          logits: Predictions for this model (batch_size x num_classes).\n        \"\"\"\n        images = tf.reshape(images, [images.shape[0], -1])\n        logits = tf.layers.dense(images, num_classes, name='logits')\n        loss = sparse_softmax_cross_entropy(labels, logits, num_classes)\n        return loss, logits\n\n      model = linear_model\n      layer_collection = lc.LayerCollection()\n      num_towers = 2\n      batch_size = num_towers\n      num_classes = 2\n\n      # Set up data.\n      images = tf.random_uniform(shape=[batch_size, 32, 32, 1])\n      labels = tf.random_uniform(\n          dtype=tf.int64, shape=[batch_size, 1], maxval=num_classes)\n\n      tower_images = tf.split(images, num_towers)\n      tower_labels = tf.split(labels, num_towers)\n\n      # Build model.\n      losses = []\n      logits = []\n      for tower_id in range(num_towers):\n        tower_name = 'tower%d' % tower_id\n        with tf.name_scope(tower_name):\n          with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):\n            current_loss, current_logits = model(\n                tower_images[tower_id], tower_labels[tower_id], num_classes + 1)\n            layer_collection.register_categorical_predictive_distribution(\n   
             current_logits, name='logits')\n            losses.append(current_loss)\n            logits.append(current_logits)\n\n      # Run the graph scanner.\n      with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):\n        gs.register_layers(layer_collection, tf.trainable_variables())\n      self.assertEqual(len(layer_collection.fisher_blocks), 1)\n      fisher_block = list(layer_collection.fisher_blocks.values())[0]\n      self.assertIsInstance(fisher_block, fb.FullyConnectedKFACBasicFB)\n      self.assertEqual(fisher_block.num_registered_towers, num_towers)\n\n      global_step = tf.train.get_or_create_global_step()\n      opt = optimizer.KfacOptimizer(\n          learning_rate=0.1,\n          cov_ema_decay=0.1,\n          damping=0.1,\n          layer_collection=layer_collection,\n          momentum=0.1)\n      cost = tf.reduce_mean(losses)\n      (cov_update_thunks,\n       inv_update_thunks) = opt.make_vars_and_create_op_thunks()\n      cov_update_op = tf.group(*(thunk() for thunk in cov_update_thunks))\n      inv_update_op = tf.group(*(thunk() for thunk in inv_update_thunks))\n      train_op = opt.minimize(cost, global_step=global_step)\n      init = tf.global_variables_initializer()\n\n      # Run a single training step.\n      with self.test_session() as sess:\n        sess.run(init)\n        sess.run([cov_update_op])\n        sess.run([inv_update_op])\n        sess.run([train_op])\n\n  def test_multitower_multi_loss_function(self):\n    \"\"\"Test multitower setup with multiple loss functions.\n\n    The automatic graph scanner should handle multiple loss functions per tower,\n    as long as they're registered in a consistent order.\n    \"\"\"\n    with tf.Graph().as_default():\n      w_1 = tf.get_variable('w_1', shape=[10, 10])\n      b_1 = tf.get_variable('b_1', shape=[10])\n      w_2 = tf.get_variable('w_2', shape=[10, 10])\n      b_2 = tf.get_variable('b_2', shape=[10])\n      layer_collection = lc.LayerCollection()\n      
layer_collection_manual = lc.LayerCollection()\n      for tower_num in range(5):\n        x = tf.placeholder(tf.float32, shape=(32, 10))\n        logits_1 = tf.matmul(x, w_1) + b_1\n        logits_2 = tf.matmul(x, w_2) + b_2\n        if tower_num == 0:\n          reuse = False\n        else:\n          reuse = True\n        with tf.variable_scope('tower%d' % tower_num, reuse=reuse):\n          for l in [layer_collection, layer_collection_manual]:\n            l.register_categorical_predictive_distribution(\n                logits_1, name='loss_1')\n            l.register_categorical_predictive_distribution(\n                logits_2, name='loss_2')\n          layer_collection_manual.register_fully_connected((w_1, b_1), x,\n                                                           logits_1)\n          layer_collection_manual.register_fully_connected((w_2, b_2), x,\n                                                           logits_2)\n\n      gs.register_layers(layer_collection,\n                         tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))\n\n      assert_fisher_blocks_match(self, layer_collection,\n                                 layer_collection_manual)\n\n  def test_filter_user_registered_records(self):\n    \"\"\"Matches containing already registered variables should be removed.\"\"\"\n    with tf.Graph().as_default():\n      tensor_dict, record_list_dict = _build_mock_records()\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_fully_connected(\n          params=(tensor_dict['w'], tensor_dict['b_1']),\n          inputs=tensor_dict['x'],\n          outputs=tensor_dict['pre_bias_0'])\n      user_registered_variables = set()\n      for params in layer_collection.fisher_blocks.keys():\n        for variable in gs.ensure_sequence(params):\n          user_registered_variables.add(variable)\n      filtered_record_list_dict = gs.filter_user_registered_records(\n          record_list_dict, user_registered_variables)\n     
 expected_keys = [tensor_dict['b_0']]\n      self.assertDictEqual(filtered_record_list_dict,\n                           {k: record_list_dict[k]\n                            for k in expected_keys})\n\n  def test_filter_grouped_variable_records(self):\n    \"\"\"Matches violating specified parameter groupings should be removed.\"\"\"\n    with tf.Graph().as_default():\n      tensor_dict, record_list_dict = _build_mock_records()\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.define_linked_parameters(params=tensor_dict['w'])\n      filtered_record_list_dict = gs.filter_grouped_variable_records(\n          layer_collection, record_list_dict)\n      expected_keys = [tensor_dict['w'], tensor_dict['b_0']]\n      self.assertDictEqual(filtered_record_list_dict,\n                           {k: record_list_dict[k]\n                            for k in expected_keys})\n\n    with tf.Graph().as_default():\n      tensor_dict, record_list_dict = _build_mock_records()\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.define_linked_parameters(\n          params=(tensor_dict['w'], tensor_dict['b_0']))\n      filtered_record_list_dict = gs.filter_grouped_variable_records(\n          layer_collection, record_list_dict)\n      expected_keys = [(tensor_dict['w'], tensor_dict['b_0'])]\n      self.assertDictEqual(filtered_record_list_dict,\n                           {k: record_list_dict[k]\n                            for k in expected_keys})\n\n  def test_filter_subgraph_records(self):\n    \"\"\"Matches that are strict subgraphs of other matches should be removed.\"\"\"\n    with tf.Graph().as_default():\n      tensor_dict, record_list_dict = _build_mock_records()\n      filtered_record_list_dict = gs.filter_subgraph_records(record_list_dict)\n      expected_keys = [(tensor_dict['w'], tensor_dict['b_0']),\n                       (tensor_dict['w'], tensor_dict['b_1'])]\n      self.assertDictEqual(filtered_record_list_dict,\n              
             {k: record_list_dict[k]\n                            for k in expected_keys})\n\n  def test_rnn_multi(self):\n    \"\"\"Test automatic registration on a static RNN.\n\n    The model tested here is designed for MNIST classification. To classify\n    images using a recurrent neural network, we consider every image row as a\n    sequence of pixels. Because MNIST image shape is 28*28px, we will then\n    handle 28 sequences of 28 steps for every sample.\n    \"\"\"\n    with tf.Graph().as_default():\n      dtype = tf.float32\n      n_input = 28  # MNIST data input (img shape: 28*28)\n      n_timesteps = 28  # timesteps\n      n_hidden = 128  # hidden layer num of features\n      n_classes = 10  # MNIST total classes (0-9 digits)\n\n      x = tf.placeholder(dtype, [None, n_timesteps, n_input])\n      y = tf.placeholder(tf.int32, [None])\n      x_unstack = tf.unstack(x, n_timesteps, 1)\n\n      w_input = tf.get_variable(\n          'w_input', shape=[n_input, n_hidden], dtype=dtype)\n      b_input = tf.get_variable('b_input', shape=[n_hidden], dtype=dtype)\n\n      w_recurrent = tf.get_variable(\n          'w_recurrent', shape=[n_hidden, n_hidden], dtype=dtype)\n      b_recurrent = tf.get_variable(\n          'b_recurrent', shape=[n_hidden], dtype=dtype)\n\n      w_output = tf.get_variable(\n          'w_output', shape=[n_hidden, n_classes], dtype=dtype)\n      b_output = tf.get_variable('b_output', shape=[n_classes], dtype=dtype)\n\n      layer_collection_manual = lc.LayerCollection()\n      layer_collection_auto = lc.LayerCollection()\n\n      a = tf.zeros(tf.convert_to_tensor([tf.shape(x_unstack[0])[0], n_hidden]),\n                   dtype=dtype)\n\n      # Here 'a' are the activations, 's' the pre-activations.\n      a_list = [a]\n      s_input_list = []\n      s_recurrent_list = []\n      s_list = []\n      s_out_list = []\n      cost = 0.0\n\n      for i in range(len(x_unstack)):\n        input_ = x_unstack[i]\n\n        s_in = tf.matmul(input_, 
w_input) + b_input\n        s_rec = tf.matmul(a, w_recurrent) + b_recurrent\n        s = s_in + s_rec\n\n        s_input_list.append(s_in)\n        s_recurrent_list.append(s_rec)\n        s_list.append(s)\n\n        a = tf.tanh(s)\n        a_list.append(a)\n\n        s_out = tf.matmul(a, w_output) + b_output\n        s_out_list.append(s_out)\n\n        if i == len(x_unstack) - 1:\n          labels = y\n        else:\n          labels = tf.zeros([tf.shape(y)[0]], dtype=tf.int32)\n\n        cost += tf.reduce_mean(\n            tf.nn.sparse_softmax_cross_entropy_with_logits(\n                logits=s_out, labels=labels))\n\n        layer_collection_manual.register_categorical_predictive_distribution(\n            s_out)\n        layer_collection_auto.register_categorical_predictive_distribution(\n            s_out)\n\n      layer_collection_manual.register_fully_connected_multi(\n          (w_input, b_input), x_unstack, s_input_list)\n      layer_collection_manual.register_fully_connected_multi(\n          (w_recurrent, b_recurrent), a_list[:-1], s_recurrent_list)\n      layer_collection_manual.register_fully_connected_multi(\n          (w_output, b_output), a_list[1:], s_out_list)\n\n      gs.register_layers(layer_collection_auto,\n                         tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))\n\n      assert_fisher_blocks_match(self, layer_collection_manual,\n                                 layer_collection_auto)\n\n  def test_graph_search_match_fail(self):\n    \"\"\"Tests graph search with linked bias tensors.\n\n    In this code snippet two non adjacent bias tensors are linked together.\n    There is no fisher block in kfac that matches this configuration, so the\n    biases should not be registered.\n    \"\"\"\n    with tf.Graph().as_default():\n      tensor_dict = _build_model()\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(tensor_dict['out_0'])\n      
layer_collection.register_squared_error_loss(tensor_dict['out_1'])\n\n      # TODO(b/69055612): remove this manual registration once layer_collection\n      # implements register_fully_connected_multi.\n      layer_collection.register_fully_connected(\n          tensor_dict['w'], tensor_dict['x'], tensor_dict['pre_bias_0'])\n      layer_collection.define_linked_parameters((tensor_dict['b_0'],\n                                                 tensor_dict['b_1']))\n\n      with self.assertRaises(ValueError) as cm:\n        gs.register_layers(layer_collection,\n                           tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))\n\n      self.assertIn('in linked group', str(cm.exception))\n      self.assertIn('was not matched', str(cm.exception))\n      self.assertIn(\n          str(frozenset([tensor_dict['b_0'], tensor_dict['b_1']])),\n          str(cm.exception))\n\n  def test_specify_approximation(self):\n    \"\"\"Test specifying approximations.\n\n    If linked parameters are identified along with an approximation, then\n    that approximation should be used when registering those parameters.\n    \"\"\"\n    with tf.Graph().as_default():\n      w_0 = tf.get_variable('w_0', [10, 10])\n      w_1 = tf.get_variable('w_1', [10, 10])\n\n      b_0 = tf.get_variable('b_0', [10])\n      b_1 = tf.get_variable('b_1', [10])\n\n      x_0 = tf.placeholder(tf.float32, shape=(32, 10))\n      x_1 = tf.placeholder(tf.float32, shape=(32, 10))\n\n      pre_bias_0 = tf.matmul(x_0, w_0)\n      pre_bias_1 = tf.matmul(x_1, w_1)\n\n      out_0 = pre_bias_0 + b_0  # pylint: disable=unused-variable\n      out_1 = pre_bias_1 + b_1  # pylint: disable=unused-variable\n\n      # Group variables as affine layers.\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(out_0)\n      layer_collection.register_squared_error_loss(out_1)\n\n      layer_collection.define_linked_parameters(\n          (w_0, b_0), 
approximation=lc.APPROX_KRONECKER_NAME)\n      layer_collection.define_linked_parameters(\n          (w_1, b_1), approximation=lc.APPROX_DIAGONAL_NAME)\n      gs.register_layers(\n          layer_collection,\n          tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),\n          batch_size=32)\n      self.assertIsInstance(layer_collection.fisher_blocks[(w_0, b_0)],\n                            fb.FullyConnectedKFACBasicFB)\n      self.assertIsInstance(layer_collection.fisher_blocks[(w_1, b_1)],\n                            fb.FullyConnectedDiagonalFB)\n\n      # Group variables as linear layers and generic parameters.\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(out_0)\n      layer_collection.register_squared_error_loss(out_1)\n\n      layer_collection.define_linked_parameters(\n          w_0, approximation=lc.APPROX_DIAGONAL_NAME)\n      layer_collection.define_linked_parameters(\n          b_0, approximation=lc.APPROX_DIAGONAL_NAME)\n      layer_collection.define_linked_parameters(\n          w_1, approximation=lc.APPROX_KRONECKER_NAME)\n      layer_collection.define_linked_parameters(\n          b_1, approximation=lc.APPROX_FULL_NAME)\n      gs.register_layers(\n          layer_collection,\n          tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),\n          batch_size=32)\n      self.assertIsInstance(layer_collection.fisher_blocks[w_0],\n                            fb.FullyConnectedDiagonalFB)\n      self.assertIsInstance(layer_collection.fisher_blocks[b_0],\n                            fb.NaiveDiagonalFB)\n      self.assertIsInstance(layer_collection.fisher_blocks[w_1],\n                            fb.FullyConnectedKFACBasicFB)\n      self.assertIsInstance(layer_collection.fisher_blocks[b_1], fb.FullFB)\n\n  def test_specify_approximation_shared_parameters(self):\n    \"\"\"Test specifying approximations with layers containing shared parameters.\n\n    If linked parameters are identified along 
with an approximation, then\n    that approximation should be used when registering those parameters.\n    \"\"\"\n    with tf.Graph().as_default():\n      tensor_dict = _build_model()\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(tensor_dict['out_0'])\n      layer_collection.register_squared_error_loss(tensor_dict['out_1'])\n\n      layer_collection.define_linked_parameters(\n          tensor_dict['w'], approximation=lc.APPROX_KRONECKER_INDEP_NAME)\n      layer_collection.define_linked_parameters(\n          tensor_dict['b_0'], approximation=lc.APPROX_DIAGONAL_NAME)\n      layer_collection.define_linked_parameters(\n          tensor_dict['b_1'], approximation=lc.APPROX_FULL_NAME)\n\n      gs.register_layers(\n          layer_collection,\n          tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),\n          batch_size=1)\n\n      self.assertIsInstance(layer_collection.fisher_blocks[tensor_dict['w']],\n                            fb.FullyConnectedMultiIndepFB)\n      self.assertIsInstance(\n          layer_collection.fisher_blocks[tensor_dict['b_0']],\n          fb.NaiveDiagonalFB)\n      self.assertIsInstance(\n          layer_collection.fisher_blocks[tensor_dict['b_1']], fb.FullFB)\n\n  def test_tied_weights_untied_bias_registered_weights(self):\n    \"\"\"Tests that graph search produces right solution on toy model.\"\"\"\n    with tf.Graph().as_default():\n      tensor_dict = _build_model()\n\n      layer_collection_manual = lc.LayerCollection()\n      layer_collection_manual.register_squared_error_loss(tensor_dict['out_0'])\n      layer_collection_manual.register_squared_error_loss(tensor_dict['out_1'])\n\n      layer_collection_manual.register_fully_connected_multi(\n          tensor_dict['w'], (tensor_dict['x'], tensor_dict['y']),\n          (tensor_dict['pre_bias_0'], tensor_dict['pre_bias_1']))\n      layer_collection_manual.register_generic(tensor_dict['b_0'], batch_size=1)\n      
layer_collection_manual.register_generic(tensor_dict['b_1'], batch_size=1)\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(tensor_dict['out_0'])\n      layer_collection.register_squared_error_loss(tensor_dict['out_1'])\n\n      layer_collection.define_linked_parameters((tensor_dict['w']))\n      gs.register_layers(\n          layer_collection,\n          tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),\n          batch_size=1)\n\n      assert_fisher_blocks_match(self, layer_collection,\n                                 layer_collection_manual)\n\n  def test_tied_weights_untied_bias_registered_affine(self):\n    \"\"\"Test registering linked variables.\n\n    Registering (w, b_1) as linked variables should not raise an error, since\n    the matches with parameters (w) and (w, b_0) will be filtered out.\n    \"\"\"\n    with tf.Graph().as_default():\n      tensor_dict = _build_model()\n\n      layer_collection_manual = lc.LayerCollection()\n      layer_collection_manual.register_squared_error_loss(tensor_dict['out_0'])\n      layer_collection_manual.register_squared_error_loss(tensor_dict['out_1'])\n\n      layer_collection_manual.register_fully_connected(\n          params=(tensor_dict['w'], tensor_dict['b_1']),\n          inputs=tensor_dict['y'],\n          outputs=tensor_dict['out_1'])\n      layer_collection_manual.register_generic(\n          tensor_dict['b_0'], batch_size=32)\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(tensor_dict['out_0'])\n      layer_collection.register_squared_error_loss(tensor_dict['out_1'])\n\n      layer_collection.define_linked_parameters((tensor_dict['w'],\n                                                 tensor_dict['b_1']))\n      gs.register_layers(\n          layer_collection,\n          tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),\n          batch_size=32)\n\n      assert_fisher_blocks_match(self, 
layer_collection,\n                                 layer_collection_manual)\n\n  def test_tied_weights_untied_bias(self):\n    \"\"\"Tests that ambiguity in graph raises an error.\n\n    Graph search will find several possible registrations containing w including\n    (w, b_1) & (w, b_2). Without any instructions in form of linked tensors or\n    manual registration it defaults to registering an error and suggesting that\n    the user register (w) as a linked tensor.\n    \"\"\"\n    with tf.Graph().as_default():\n      tensor_dict = _build_model()\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(tensor_dict['out_0'])\n      layer_collection.register_squared_error_loss(tensor_dict['out_1'])\n\n      with self.assertRaises(gs.AmbiguousRegistrationError):\n        gs.register_layers(layer_collection,\n                           tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))\n\n  def test_tied_weights_untied_bias_registered_bias(self):\n    \"\"\"Tests that ambiguity in graph raises value error.\n\n    Graph search will find several possible registrations for tensors.\n    In this registering b_1 as a linked variable will result in an error\n    because there will remain an ambiguity on the other branch of the graph.\n    \"\"\"\n    with tf.Graph().as_default():\n      tensor_dict = _build_model()\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(tensor_dict['out_0'])\n      layer_collection.register_squared_error_loss(tensor_dict['out_1'])\n\n      layer_collection.define_linked_parameters((tensor_dict['b_1']))\n\n      with self.assertRaises(gs.AmbiguousRegistrationError):\n        gs.register_layers(layer_collection,\n                           tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))\n\n  def test_multi_time_batch_fold(self):\n    \"\"\"Test that graph search provides desired registration on toy model.\n\n      In this toy example we apply 
the same linear layer to two different\n      inputs. This tests whether graph search can correctly group them. Also\n      tests whether batch/time folded is correctly registered as fully\n      connected multi fisher blocks.\n    \"\"\"\n    with tf.Graph().as_default():\n      w = tf.get_variable('W', [10, 10])\n      b_0 = tf.get_variable('b_0', [\n          10,\n      ])\n      x = tf.placeholder(tf.float32, shape=(32, 10))\n      y = tf.placeholder(tf.float32, shape=(32, 10))\n\n      out_0 = tf.matmul(x, w) + b_0\n      out_1 = tf.matmul(y, w) + b_0\n\n      layer_collection_manual = lc.LayerCollection()\n      layer_collection_manual.register_squared_error_loss(out_0)\n      layer_collection_manual.register_squared_error_loss(out_1)\n\n      layer_collection_manual.register_fully_connected_multi(\n          (w, b_0), (x, y), (out_0, out_1), num_uses=2)\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(out_0)\n      layer_collection.register_squared_error_loss(out_1)\n\n      gs.register_layers(\n          layer_collection,\n          tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),\n          batch_size=16)\n\n      assert_fisher_blocks_match(self, layer_collection,\n                                 layer_collection_manual)\n\n  def test_multiple_weights(self):\n    \"\"\"Test that graph search provides desired registration on toy model.\n\n    In this toy example we apply the same linear layer to two different inputs.\n    This tests whether graph search can correctly group them.\n    \"\"\"\n    with tf.Graph().as_default():\n      w = tf.get_variable('W', [10, 10])\n      b_0 = tf.get_variable('b_0', [\n          10,\n      ])\n      x = tf.placeholder(tf.float32, shape=(32, 10))\n      y = tf.placeholder(tf.float32, shape=(32, 10))\n\n      out_0 = tf.matmul(x, w) + b_0\n      out_1 = tf.matmul(y, w) + b_0\n\n      layer_collection_manual = lc.LayerCollection()\n      
layer_collection_manual.register_fully_connected_multi((w, b_0), (x, y),\n                                                             (out_0, out_1))\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(out_0)\n      layer_collection.register_squared_error_loss(out_1)\n\n      gs.register_layers(layer_collection,\n                         tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))\n\n      assert_fisher_blocks_match(self, layer_collection,\n                                 layer_collection_manual)\n\n  def test_subset_weights_manual_registration(self):\n    \"\"\"Test that graph search provides desired registration on toy model.\n\n    In this toy example we apply the same matmul op to two different inputs\n    followed by adding a bias to one of the inputs. This tests whether graph\n    search can correctly group them.\n    \"\"\"\n    with tf.Graph().as_default():\n      w = tf.get_variable('W', [10, 10])\n      b_0 = tf.get_variable('b_0', [10,])\n      x = tf.placeholder(tf.float32, shape=(32, 10))\n      y = tf.placeholder(tf.float32, shape=(32, 10))\n\n      out_n1 = tf.matmul(x, w)\n      out_0 = out_n1 + b_0\n      out_1 = tf.matmul(y, w)\n\n      layer_collection_manual = lc.LayerCollection()\n      layer_collection_manual.register_fully_connected_multi(\n          w, (x, y), (out_n1, out_1))\n      layer_collection_manual.register_generic(b_0, batch_size=1)\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(out_0)\n      layer_collection.register_squared_error_loss(out_1)\n\n      layer_collection.define_linked_parameters(w)\n      gs.register_layers(\n          layer_collection,\n          tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),\n          batch_size=1)\n\n      assert_fisher_blocks_match(self, layer_collection,\n                                 layer_collection_manual)\n\n  def mixed_usage_test(self):\n    \"\"\"Tests that graph 
search raises error on mixed types usage for tensors.\n\n    Tensors can be reused in various locations in the tensorflow graph. This\n    occurs regularly in the case of recurrent models or models with parallel\n    graphs. However the tensors must be used for the same operation in each\n    location or graph search should raise an error.\n    \"\"\"\n    with tf.Graph().as_default():\n      w = tf.get_variable('W', [10, 10])\n      x = tf.placeholder(tf.float32, shape=(32, 10))\n      y = tf.placeholder(tf.float32, shape=(32, 10, 10))\n\n      out_0 = tf.matmul(x, w)  # pylint: disable=unused-variable\n      out_1 = y + w  # pylint: disable=unused-variable\n\n      layer_collection = lc.LayerCollection()\n\n      with self.assertRaises(ValueError) as cm:\n        gs.register_layers(layer_collection,\n                           tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))\n\n      self.assertIn('mixed record types', str(cm.exception))\n\n  def test_resource_variable(self):\n    \"\"\"Ensures that ResourceVariables can be matched.\"\"\"\n    with tf.Graph().as_default():\n      w = tf.get_variable('w', [10, 10], use_resource=True)\n      b = tf.get_variable('b', [10], use_resource=True)\n      x = tf.placeholder(tf.float32, shape=(32, 10))\n      out_0 = tf.matmul(x, w) + b\n\n      layer_collection = lc.LayerCollection()\n      layer_collection.register_squared_error_loss(out_0)\n\n      gs.register_layers(layer_collection, [w, b])\n\n      layer_collection_manual = lc.LayerCollection()\n      layer_collection_manual.register_squared_error_loss(out_0)\n      layer_collection_manual.register_fully_connected((w, b), x, out_0)\n\n      assert_fisher_blocks_match(self, layer_collection,\n                                 layer_collection_manual)\n      self.assertEqual(1, len(layer_collection.get_blocks()))\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/keras_callbacks_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for keras/callbacks.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.keras import callbacks\nfrom kfac.python.keras import optimizers\n\nlayers = tf.keras.layers\n_SEED = 1234\n\n\nclass HyperParamTracker(tf.keras.callbacks.Callback):\n  EPOCH, BATCH = range(2)\n\n  def __init__(self, hyper, record_list, frequency):\n    self.hyper = hyper\n    self.record_list = record_list\n    self.frequency = frequency\n\n  def on_batch_end(self, batch, logs=None):\n    if self.frequency == HyperParamTracker.BATCH:\n      val = tf.keras.backend.get_value(getattr(self.model.optimizer,\n                                               self.hyper))\n      self.record_list.append(val)\n\n  def on_epoch_end(self, epoch, logs=None):\n    if self.frequency == HyperParamTracker.EPOCH:\n      val = tf.keras.backend.get_value(getattr(self.model.optimizer,\n                                               self.hyper))\n      self.record_list.append(val)\n\n\nclass CallbacksTest(parameterized.TestCase, tf.test.TestCase):\n\n  def __init__(self, *args, **kwargs):\n    
super(CallbacksTest, self).__init__(*args, **kwargs)\n    self.batch_size = 16\n    self.num_steps = 20\n    self.data = np.random.random((self.batch_size*self.num_steps))\n    self.labels = np.random.random((self.batch_size*self.num_steps))\n\n  def setUp(self):\n    super(CallbacksTest, self).setUp()\n    self.model = tf.keras.Sequential([layers.Dense(1, input_shape=(1,))])\n    tf.random.set_random_seed(_SEED)\n\n  def testPolynomialDecayValues(self):\n    init_value = 0.01\n    final_value = 0.0002\n    power = 0.6\n    num_decay_steps = 11\n    num_delay_steps = 3\n    opt = tf.keras.optimizers.Adam(learning_rate=init_value)\n    self.model.compile(opt, 'mse')\n    lr_list = []\n    cbs = [\n        callbacks.PolynomialDecay(hyperparameter='learning_rate',\n                                  init_value=init_value,\n                                  final_value=final_value,\n                                  power=power,\n                                  num_decay_steps=num_decay_steps,\n                                  num_delay_steps=num_delay_steps,\n                                  verbose=1),\n        HyperParamTracker('learning_rate', lr_list, HyperParamTracker.BATCH)\n    ]\n    self.model.fit(\n        self.data, self.labels, batch_size=self.batch_size, callbacks=cbs)\n    expected_list = [init_value] * num_delay_steps + [\n        (init_value - final_value) *\n        (1 - min(i, num_decay_steps) / float(num_decay_steps)) ** power +\n        final_value for i in range(self.num_steps - num_delay_steps)\n    ]\n    self.assertAllClose(lr_list, expected_list)\n\n  def testExponentialDampingValuesWithDecayRate(self):\n    init_value = 0.01\n    decay_rate = 0.3\n    num_decay_steps = 4\n    num_delay_steps = 3\n    opt = optimizers.Kfac(\n        learning_rate=0.01, damping=init_value, model=self.model, loss='mse')\n    self.model.compile(opt, 'mse')\n    damping_list = []\n    cbs = [\n        callbacks.ExponentialDecay(hyperparameter='damping',\n       
                            init_value=init_value,\n                                   decay_rate=decay_rate,\n                                   num_decay_steps=num_decay_steps,\n                                   num_delay_steps=num_delay_steps,\n                                   verbose=1),\n        HyperParamTracker('damping', damping_list, HyperParamTracker.BATCH)\n    ]\n    self.model.fit(\n        self.data, self.labels, batch_size=self.batch_size, callbacks=cbs)\n\n    expected_list = [init_value] * num_delay_steps + [\n        init_value * decay_rate ** min(i, num_decay_steps)\n        for i in range(self.num_steps - num_delay_steps)\n    ]\n    self.assertAllClose(damping_list, expected_list)\n\n  def testExponentialDampingValuesWithFinalValue(self):\n    init_value = 0.01\n    final_value = 0.0001\n    num_decay_steps = 4\n    num_delay_steps = 3\n    opt = optimizers.Kfac(\n        learning_rate=0.01, damping=init_value, model=self.model, loss='mse')\n    self.model.compile(opt, 'mse')\n    damping_list = []\n    cbs = [\n        callbacks.ExponentialDecay(hyperparameter='damping',\n                                   init_value=init_value,\n                                   final_value=final_value,\n                                   num_decay_steps=num_decay_steps,\n                                   num_delay_steps=num_delay_steps,\n                                   verbose=1),\n        HyperParamTracker('damping', damping_list, HyperParamTracker.BATCH)\n    ]\n    self.model.fit(\n        self.data, self.labels, batch_size=self.batch_size, callbacks=cbs)\n\n    expected_list = [init_value] * num_delay_steps + [\n        init_value * (final_value/init_value) **\n        (min(i, num_decay_steps)*1./num_decay_steps)\n        for i in range(self.num_steps - num_delay_steps)\n    ]\n    self.assertAllClose(damping_list, expected_list)\n    self.assertNear(damping_list[-1], final_value, err=1e-5)\n\n  def 
testExponentialDampingValuesWithFinalValueAndRate(self):\n    init_value = 0.01\n    final_value = 0.0001\n    decay_rate = 0.6\n    num_delay_steps = 3\n    opt = optimizers.Kfac(\n        learning_rate=0.01, damping=init_value, model=self.model, loss='mse')\n    self.model.compile(opt, 'mse')\n    damping_list = []\n    cbs = [\n        callbacks.ExponentialDecay(hyperparameter='damping',\n                                   init_value=init_value,\n                                   final_value=final_value,\n                                   decay_rate=decay_rate,\n                                   num_delay_steps=num_delay_steps,\n                                   verbose=1),\n        HyperParamTracker('damping', damping_list, HyperParamTracker.BATCH)\n    ]\n    self.model.fit(\n        self.data, self.labels, batch_size=self.batch_size, callbacks=cbs)\n\n    expected_list = [init_value] * num_delay_steps + [\n        max((init_value * decay_rate ** i), final_value)\n        for i in range(self.num_steps - num_delay_steps)\n    ]\n    self.assertAllClose(damping_list, expected_list)\n    self.assertNear(damping_list[-1], final_value, err=1e-5)\n\n  @parameterized.named_parameters(\n      ('_Exponential', 'damping',\n       callbacks.ExponentialDecay(hyperparameter='damping',\n                                  init_value=0.01,\n                                  decay_rate=0.3,\n                                  num_decay_steps=30)),\n      ('_Polynomial', 'learning_rate',\n       callbacks.PolynomialDecay(hyperparameter='learning_rate',\n                                 init_value=0.001,\n                                 final_value=0.002,\n                                 power=0.6,\n                                 num_decay_steps=30)))\n  def testTrainHistory(self, hyper, callback):\n    opt = optimizers.Kfac(learning_rate=0.001, damping=0.01,\n                          model=self.model, loss='mse', num_burnin_steps=5)\n    self.model.compile(opt, 'mse')\n  
  lst = []\n    cbs = [callback, HyperParamTracker(hyper, lst, HyperParamTracker.EPOCH)]\n    hist = self.model.fit(self.data, self.labels,\n                          batch_size=self.batch_size, epochs=3, callbacks=cbs)\n    self.assertAllClose(lst, hist.history[hyper])\n\n  def testDampingDecayFailsWithNoDamping(self):\n    with self.assertRaisesRegex(ValueError, '.*must have a \"damping\".*'):\n      self.model.compile('adam', 'mse')\n      cb = callbacks.ExponentialDecay(hyperparameter='damping',\n                                      init_value=0.01,\n                                      decay_rate=0.3,\n                                      num_decay_steps=4)\n      self.model.fit(self.data, self.data, callbacks=[cb])\n\n  def testExponentialDampingFailsNoRateOrFinalValue(self):\n    with self.assertRaisesRegex(ValueError, '.*must specify exactly two of.*'):\n      callbacks.ExponentialDecay(hyperparameter='damping',\n                                 init_value=0.01)\n\n  def testExponentialDampingFailsWithAllOptionals(self):\n    with self.assertRaisesRegex(ValueError, '.*must specify exactly two of.*'):\n      callbacks.ExponentialDecay(hyperparameter='learning_rate',\n                                 init_value=0.01,\n                                 final_value=0.001,\n                                 decay_rate=0.99,\n                                 num_decay_steps=50)\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/keras_optimizers_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for keras/optimizers.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport json\nfrom absl.testing import parameterized\nfrom tensorflow.python.keras import backend\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\nfrom tensorflow.python.util import serialization\nfrom kfac.python.keras import optimizers\nfrom kfac.python.keras import utils\n\nlayers = tf.keras.layers\nlosses = tf.keras.losses\n_SEED = 1234\n\n\n# TODO(b/135916953): Use TensorFlow test_utils instead of below helpers.\ndef _get_synthetic_mnist_dataset(train_size=64, test_size=16):\n  num_classes = 10\n  img_rows, img_cols = 28, 28\n\n  rng = np.random.RandomState(_SEED)\n  num_examples = train_size + test_size\n  images = rng.rand(num_examples, img_rows * img_cols).astype(np.float32)\n  images = np.reshape(images, [num_examples, img_rows, img_cols, 1])\n  labels = rng.randint(num_classes, size=num_examples)\n  one_hot_labels = np.eye(num_classes)[labels].astype(np.float32)\n\n  return ((images[:train_size], one_hot_labels[:train_size]),\n          (images[train_size:], one_hot_labels[train_size:]))\n\n\ndef _get_synthetic_mnist_train_tensors(\n    train_size=64, batch_size=10, 
drop_remainder=False):\n  (x_train, y_train), _ = _get_synthetic_mnist_dataset(train_size=train_size)\n  dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n  dataset = dataset.repeat().batch(batch_size, drop_remainder=drop_remainder)\n  return dataset.make_one_shot_iterator().get_next()\n\n\ndef _generate_target_fn(num_examples):\n  \"\"\"Generates a random 2d target function for regression.\n\n  Args:\n    num_examples: The number of evenly spaced examples along the function to\n      generate.\n\n  Returns:\n    A tuple of the x tensor and the y tensor for the generated function.\n  \"\"\"\n  inds = np.arange(num_examples)\n  x = np.sort(np.random.rand(num_examples) - 0.5)\n  x = np.expand_dims(x, axis=1)\n  y = np.transpose(x)\n  dist = np.square(x - y)  # Should be scipy cdist(x, x, metric='sqeuclidean')\n  k = np.exp(-dist / 0.01)\n  k += np.eye(k.shape[0]) * 1e-6\n  l = np.linalg.cholesky(k)\n  random_y = np.random.randn(x.shape[0], 1)\n  y = np.dot(l, random_y) + np.random.randn(x.shape[0], 1) * 1e-1\n  return x[inds, :], y[inds, :]\n\n\ndef _generate_regression_data(num_eg, num_train_eg):\n  x_all, y_all = _generate_target_fn(num_eg)\n  x_all = x_all.astype(np.float32)\n  y_all = y_all.astype(np.float32)\n\n  inds = np.arange(num_eg)\n  np.random.shuffle(inds)\n  x_train = x_all[inds[:num_train_eg]]\n  y_train = y_all[inds[:num_train_eg]]\n\n  x_test = x_all[inds[num_train_eg:]]\n  y_test = y_all[inds[num_train_eg:]]\n\n  return (x_train, y_train), (x_test, y_test)\n\n\ndef _simple_mlp():\n  return tf.keras.Sequential([\n      layers.Dense(32, input_shape=(1,), activation='tanh'),\n      layers.Dense(32, activation='tanh'),\n      layers.Dense(1)\n  ])\n\n\ndef _mnist_model(use_bias=True, use_separate_activation=True):\n  \"\"\"A complex architecture to test the variable registration.\n\n  This model is not intended to be a \"good\" mnist classifier.\n  It uses Lambda layers, concats, and separate branches to test effectively.\n\n  Args:\n    
use_bias: boolean. Whether all the layers use a bias term or not.\n    use_separate_activation: boolean. Whether the layers have the activation\n      within the layer or use a separate activation layer.\n\n  Returns:\n    A Keras model containing the mnist classifier.\n  \"\"\"\n  activation = 'linear' if use_separate_activation else 'relu'\n  output_activation = 'linear' if use_separate_activation else 'softmax'\n\n  inp = layers.Input(shape=(28, 28, 1))\n\n  branch1 = layers.Lambda(lambda x: tf.squeeze(x, -1))(inp)\n  branch1 = layers.Conv1D(3, kernel_size=7, activation=activation,\n                          use_bias=use_bias)(branch1)\n  if use_separate_activation:\n    branch1 = layers.Activation('relu')(branch1)\n  branch1 = layers.GlobalMaxPool1D()(branch1)\n\n  branch2 = layers.Conv2D(16, kernel_size=(3, 3), activation=activation,\n                          use_bias=use_bias)(inp)\n  if use_separate_activation:\n    branch2 = layers.Activation('relu')(branch2)\n  branch2 = layers.MaxPooling2D(pool_size=(4, 4))(branch2)\n  branch2 = layers.Flatten()(branch2)\n  branch2 = layers.Dense(20, use_bias=use_bias)(branch2)\n  if use_separate_activation:\n    branch2 = layers.Activation('relu')(branch2)\n\n  out = layers.concatenate([branch1, branch2])\n  out = layers.Dense(10, use_bias=use_bias, activation=output_activation)(out)\n  if use_separate_activation:\n    out = layers.Activation('softmax')(out)\n\n  return tf.keras.Model(inputs=inp, outputs=out)\n\n\ndef _train_model(data,\n                 model,\n                 loss,\n                 lr=0.001,\n                 damping=0.001,\n                 batch_size=32,\n                 epochs=1,\n                 loss_weights=None):\n  \"\"\"Compiles and fits model to data and returns training results.\n\n  Args:\n    data: Tuple of numpy arrays shaped ((x_train, y_train), (x_test, y_test)).\n    model: Uncompiled Keras model with inputs/output shapes matching the data.\n    loss: tf.keras.losses loss function 
or serialized (string) loss function.\n    lr: Learning rate for optimizer.\n    damping: Damping parameter for KFAC.\n    batch_size: Batch size used for training.\n    epochs: Number of training epochs.\n    loss_weights: List of weights or dict mapping layer names to loss function\n      weight.\n\n  Returns:\n    A History object. Calling History.history gives you a dictionary with\n    training and validation results.\n  \"\"\"\n  (x_train, y_train), valid_data = data\n  opt = optimizers.Kfac(learning_rate=lr, damping=damping, model=model,\n                            loss=loss, loss_weights=loss_weights)\n  model.compile(opt, loss, loss_weights=loss_weights)\n\n  return model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,\n                   validation_data=valid_data, verbose=0)\n\n\nclass KfacOptimizerTest(parameterized.TestCase, tf.test.TestCase):\n\n  def __init__(self, *args, **kwargs):\n    super(KfacOptimizerTest, self).__init__(*args, **kwargs)\n    self._mnist_data = _get_synthetic_mnist_dataset()\n\n  def setUp(self):\n    super(KfacOptimizerTest, self).setUp()\n    tf.random.set_random_seed(_SEED)\n    np.random.seed(_SEED)\n\n  def testFunctionalInstantiation(self):\n    inputs = layers.Input(shape=(3,))\n    x = layers.Dense(4, activation=tf.nn.relu)(inputs)\n    outputs = layers.Dense(5, activation=tf.nn.softmax)(x)\n    model = tf.keras.Model(inputs=inputs, outputs=outputs)\n    optimizers.Kfac(learning_rate=0.002, damping=0.04,\n                        model=model, loss='binary_crossentropy')\n\n  def testSequentialInstantiation(self):\n    model = tf.keras.Sequential([\n        layers.Conv2D(7, (3, 3), input_shape=(28, 28, 3)),\n        layers.Activation('relu'),\n        layers.Conv2D(13, (3, 3), activation='relu'),\n        layers.GlobalMaxPool2D(),\n        layers.Activation('softmax')\n    ])\n    optimizers.Kfac(learning_rate=0.03, damping=0.00007,\n                        model=model, loss='binary_crossentropy')\n\n  def 
testInstantiationWithLayerCollection(self):\n    model = _simple_mlp()\n    lc = utils.get_layer_collection(model, 'mse')\n    opt = optimizers.Kfac(\n        learning_rate=0.1, damping=0.2, layer_collection=lc)\n    model.compile(optimizer=opt, loss='mse')\n    opt.get_updates(model.total_loss, model.trainable_weights)\n\n  def testRNNFails(self):\n    model = tf.keras.Sequential()\n    model.add(layers.Embedding(43, 128))\n    model.add(layers.LSTM(128, dropout=0.2, recurrent_dropout=0.2))\n    model.add(layers.Dense(1, activation='sigmoid'))\n    opt = optimizers.Kfac(learning_rate=0.003, damping=0.003,\n                              model=model, loss='binary_crossentropy')\n    with self.assertRaisesRegex(ValueError,\n                                '.*lstm.* has more than one parent tensor.$'):\n      opt._create_optimizer()\n\n  @parameterized.named_parameters(('BiasCombinedActivation', True, True),\n                                  ('BiasSeparateActivation', True, False),\n                                  ('NoBiasCombinedActivation', False, True),\n                                  ('NoBiasSeparateActivation', False, False))\n  def testBiasAndActivations(self, use_bias, use_separate_activation):\n    model = _mnist_model(use_bias=use_bias,\n                         use_separate_activation=use_separate_activation)\n    _train_model(self._mnist_data, model, 'categorical_crossentropy')\n\n  def testRegression(self):\n    hist = _train_model(\n        _generate_regression_data(200, 150), _simple_mlp(), 'mse', epochs=5)\n    val_loss = hist.history['val_loss']\n    self.assertGreater(val_loss[0], val_loss[-1])\n\n  def testClipNormFails(self):\n    with self.assertRaises(ValueError):\n      optimizers.Kfac(learning_rate=0.001, damping=0.001,\n                          model=_simple_mlp(), loss='mse', clipnorm=0.1)\n\n  def testClipValueFails(self):\n    with self.assertRaises(ValueError):\n      optimizers.Kfac(learning_rate=0.01, damping=0.01,\n                
          model=_simple_mlp(), loss='mse', clipvalue=0.1)\n\n  def testLossTensor(self):\n    loss_tensor = tf.convert_to_tensor(2.0)\n    opt = optimizers.Kfac(learning_rate=0.01, damping=0.01,\n                              model=_simple_mlp(), loss='mse',\n                              loss_tensor=loss_tensor)\n    self.assertEqual(opt.optimizer._loss_tensor, loss_tensor)\n\n  def testArgsKwargs(self):\n    \"\"\"Test if kwargs are correctly forwarded to tensorflow_kfac.\"\"\"\n    kwargs = {\n        'learning_rate': 3.0,\n        'damping': 5.0,\n        'momentum': 7.0,\n        'min_damping': 9.0,\n        'num_burnin_steps': 11,\n        'invert_every': 13,\n        'fisher_approx': {\n            layers.Dense: 'kron_in_diag',\n            'dense_1': 'kron_both_diag'\n        },\n    }\n    model = _simple_mlp()\n    opt = optimizers.Kfac(model=model, loss='mse', **kwargs)\n    self.assertEqual(opt.optimizer._min_damping, kwargs['min_damping'])\n    self.assertEqual(opt.optimizer._num_burnin_steps,\n                     kwargs['num_burnin_steps'])\n    self.assertEqual(opt.optimizer._invert_every, kwargs['invert_every'])\n\n    fisher_block_0 = opt.optimizer.layers.fisher_blocks[model.layers[0].weights]\n    self.assertTrue(fisher_block_0._diagonal_approx_for_input)\n    self.assertFalse(fisher_block_0._diagonal_approx_for_output)\n    fisher_block_1 = opt.optimizer.layers.fisher_blocks[model.layers[1].weights]\n    self.assertTrue(fisher_block_1._diagonal_approx_for_input)\n    self.assertTrue(fisher_block_1._diagonal_approx_for_output)\n\n    with tf.Session() as sess:\n      # In Keras, typically you do not use sessions directly. When you use a\n      # Keras component, the required variables are initialized for you because\n      # they are tracked. 
Here, we explicitly run the variables in a session so\n      # they must be initialized.\n      sess.run(tf.global_variables_initializer())\n      self.assertEqual(sess.run(opt.optimizer.momentum), kwargs['momentum'])\n      self.assertEqual(sess.run(opt.optimizer.learning_rate),\n                       kwargs['learning_rate'])\n      self.assertEqual(sess.run(opt.optimizer.damping), kwargs['damping'])\n\n  def testConfig(self):\n    fisher_approx = {layers.Dense: 'kron_in_diag', 'dense_1': 'kron_both_diag'}\n    kwargs = {\n        'loss': 'mse',\n        'momentum': 7.0,\n        'num_burnin_steps': 11.0,\n        'min_damping': 9.0,\n        'invert_every': 13,\n        'fisher_approx': fisher_approx,\n        'seed': 12,\n    }\n    opt = optimizers.Kfac(\n        learning_rate=3.0, damping=5.0, model=_simple_mlp(), **kwargs)\n    opt.learning_rate = 23.0\n    opt.damping = 27.0\n    config = opt.get_config()\n    self.assertEqual(config['learning_rate'], 23.0)\n    self.assertEqual(config['damping'], 27.0)\n    dense_approx = fisher_approx.pop(layers.Dense)\n    fisher_approx[utils._CLASS_NAME_PREFIX + 'Dense'] = dense_approx\n    for key, val in kwargs.items():\n      self.assertEqual(config[key], val)\n      # Below is how Keras's model.save saves the configs. 
If the config is not\n      # serializable, it will throw a TypeError or OverflowError.\n    json.dumps(config, default=serialization.get_json_type).encode('utf8')\n\n  @parameterized.named_parameters(('_LossName', {'loss': 'mse'}),\n                                  ('_LossFunction', {'loss': losses.MSE}))\n  def testFromConfig(self, kwargs_updates):\n    kwargs = {\n        'learning_rate': 3.0,\n        'damping': 5.0,\n        'momentum': 7.0,\n        'min_damping': 9.0,\n        'num_burnin_steps': 11,\n        'invert_every': 13,\n        'fisher_approx': {\n            layers.Dense: 'kron_in_diag',\n            'dense_1': 'kron_both_diag'\n        },\n    }\n    kwargs.update(kwargs_updates)\n    opt = optimizers.Kfac(model=_simple_mlp(), **kwargs)\n    config = opt.get_config()\n    config['name'] = 'diff_scope_name'\n    opt2 = optimizers.Kfac.from_config(config)\n    config2 = opt2.get_config()\n    config2.pop('name')\n    config.pop('name')\n    self.assertEqual(config, config2)\n    # Below is how Keras's model.save saves the configs. 
If the config is not\n    # serializable, it will throw a TypeError or OverflowError.\n    json.dumps(config, default=serialization.get_json_type).encode('utf8')\n    json.dumps(config2, default=serialization.get_json_type).encode('utf8')\n\n  @parameterized.named_parameters(('_Tensor', tf.convert_to_tensor),\n                                  ('_Float', float))\n  def testGettingHyper(self, hyper_ctor):\n    kwarg_values = {'learning_rate': 3.0, 'damping': 20.0, 'momentum': 13.0}\n    kwargs = {k: hyper_ctor(v) for k, v in kwarg_values.items()}\n    opt = optimizers.Kfac(model=_simple_mlp(), loss='mse', **kwargs)\n    get_value = backend.get_value\n    tf_opt = opt.optimizer\n    with self.subTest(name='MatchesFloat'):\n      for name, val in kwarg_values.items():\n        self.assertEqual(get_value(getattr(opt, name)), val)\n    with self.subTest(name='MatchesTfOpt'):\n      self.assertEqual(get_value(opt.lr), get_value(tf_opt.learning_rate))\n      self.assertEqual(get_value(opt.damping), get_value(tf_opt.damping))\n      self.assertEqual(get_value(opt.momentum), get_value(tf_opt.momentum))\n\n  def testGettingVariableHyperFails(self):\n    self.skipTest('This is not fixed in TF 1.14 yet.')\n    opt = optimizers.Kfac(model=_simple_mlp(),\n                          loss='mse',\n                          learning_rate=tf.Variable(0.1),\n                          damping=tf.Variable(0.1))\n    with self.assertRaisesRegex(tf.errors.FailedPreconditionError,\n                                '.*uninitialized.*'):\n      backend.get_value(opt.learning_rate)\n\n  @parameterized.named_parameters(\n      (('_' + name, name, float(val+1))\n       for val, name in enumerate(optimizers._MUTABLE_HYPER_PARAMS)))\n  def testSetTFVariableHyper(self, name, val):\n    kwargs = {'learning_rate': 0.01, 'damping': 0.001}\n    kwargs[name] = tf.Variable(45.0)\n    opt = optimizers.Kfac(model=_simple_mlp(), loss='mse', **kwargs)\n    setattr(opt, name, val)\n\n    with 
self.subTest(name='AssignedCorrectly'):\n      self.assertEqual(backend.get_value(getattr(opt, name)), val)\n      if hasattr(opt.optimizer, name):\n        self.assertEqual(backend.get_value(getattr(opt.optimizer, name)), val)\n\n    with self.subTest(name='SetError'):\n      with self.assertRaisesRegex(ValueError, 'Dynamic reassignment only.*'):\n        setattr(opt, name, tf.convert_to_tensor(2.0))\n      with self.assertRaisesRegex(ValueError, 'Dynamic reassignment only.*'):\n        setattr(opt, name, tf.Variable(2.0))\n\n  @parameterized.named_parameters(\n      (('_' + name, name, float(val + 1))\n       for val, name in enumerate(optimizers._MUTABLE_HYPER_PARAMS)))\n  def testSetFloatHyper(self, name, val):\n    kwargs = {'learning_rate': 0.01, 'damping': 0.001}\n    kwargs[name] = 45.0\n    opt = optimizers.Kfac(model=_simple_mlp(), loss='mse', **kwargs)\n    setattr(opt, name, val)\n\n    with self.subTest(name='AssignedCorrectly'):\n      self.assertEqual(backend.get_value(getattr(opt, name)), val)\n      if hasattr(opt.optimizer, name):\n        self.assertEqual(backend.get_value(getattr(opt.optimizer, name)), val)\n\n    with self.subTest(name='SetError'):\n      with self.assertRaisesRegex(ValueError, 'Dynamic reassignment only.*'):\n        setattr(opt, name, tf.convert_to_tensor(2.0))\n      with self.assertRaisesRegex(ValueError, 'Dynamic reassignment only.*'):\n        setattr(opt, name, tf.Variable(2.0))\n\n  @parameterized.named_parameters(\n      (('_' + name, name, float(val + 1))\n       for val, name in enumerate(optimizers._MUTABLE_HYPER_PARAMS)))\n  def testModifyingTensorHypersFails(self, name, val):\n    kwargs = {'learning_rate': 3.0, 'damping': 5.0, 'momentum': 7.0}\n    kwargs[name] = tf.convert_to_tensor(val)\n    opt = optimizers.Kfac(model=_simple_mlp(), loss='mse', **kwargs)\n    with self.subTest(name='AssignedCorrectly'):\n      self.assertEqual(backend.get_value(getattr(opt, name)), val)\n    with 
self.subTest(name='RaisesError'):\n      with self.assertRaisesRegex(AttributeError,\n                                  \"Can't set attribute: {}\".format(name)):\n        setattr(opt, name, 17)\n\n  def testLRBackwardsCompatibility(self):\n    \"\"\"This tests learning rate getting/setting used by old Keras callbacks.\"\"\"\n    opt = optimizers.Kfac(\n        learning_rate=3.0, damping=5.0, model=_simple_mlp(), loss='mse')\n    self.assertEqual(backend.get_value(opt.lr), 3.0)\n    self.assertEqual(backend.get_value(opt.learning_rate), 3.0)\n    opt.lr = 7.0\n    self.assertEqual(backend.get_value(opt.lr), 7.0)\n    self.assertEqual(backend.get_value(opt.learning_rate), 7.0)\n    backend.set_value(opt.lr, 9.0)\n    self.assertEqual(backend.get_value(opt.lr), 9.0)\n    self.assertEqual(backend.get_value(opt.learning_rate), 9.0)\n    backend.set_value(opt.learning_rate, 11.0)\n    self.assertEqual(backend.get_value(opt.lr), 11.0)\n    self.assertEqual(backend.get_value(opt.learning_rate), 11.0)\n\n  def testMultipleLossTraining(self):\n    inp = layers.Input(shape=(28, 28, 1))\n\n    branch1 = layers.Conv2D(13, 7, activation='relu')(inp)\n    branch1 = layers.GlobalMaxPool2D()(branch1)\n    branch1 = layers.Dense(1, name='path1')(branch1)\n\n    branch2 = layers.Conv2D(16, 3, activation='relu')(inp)\n    branch2 = layers.MaxPooling2D(pool_size=(4, 4))(branch2)\n    branch2 = layers.Flatten()(branch2)\n    branch2 = layers.Dense(9, name='path2')(branch2)\n\n    model = tf.keras.Model(inputs=inp, outputs=[branch1, branch2])\n    loss = {'path1': 'binary_crossentropy', 'path2': 'categorical_crossentropy'}\n    loss_weights = {'path1': 0.1, 'path2': 0.9}\n\n    (x, y), (valid_x, valid_y) = _get_synthetic_mnist_dataset()\n    y1, y2 = y[:, 0:1], y[:, 1:]\n    valid_y1, valid_y2 = valid_y[:, 0:1], valid_y[:, 1:]\n    data = (x, (y1, y2)), (valid_x, (valid_y1, valid_y2))\n\n    _train_model(data, model, loss, loss_weights=loss_weights)\n\n  
@parameterized.named_parameters(('_LossName', 'categorical_crossentropy'),\n                                  ('_LossFunction', losses.binary_crossentropy))\n  def testRegisterLayersWithModel(self, loss):\n    model = _mnist_model()\n    opt = optimizers.Kfac(learning_rate=0.01, damping=0.001)\n    opt.register_layers(model=model, loss=loss)\n    model.compile(optimizer=opt, loss=loss)\n    opt.get_updates(model.total_loss, model.trainable_weights)\n\n  def testRegisterLayersWithLayerCollection(self):\n    model, loss = _mnist_model(), 'categorical_crossentropy'\n    lc = utils.get_layer_collection(model, loss)\n    opt = optimizers.Kfac(learning_rate=0.01, damping=0.001)\n    opt.register_layers(layer_collection=lc)\n    model.compile(optimizer=opt, loss=loss)\n    opt.get_updates(model.total_loss, model.trainable_weights)\n\n  @parameterized.named_parameters(('_LossName', 'categorical_crossentropy'),\n                                  ('_LossFunction', losses.binary_crossentropy))\n  def testRegisterLayersCompiledModel(self, loss):\n    opt = optimizers.Kfac(learning_rate=0.01, damping=0.001)\n    model = _mnist_model()\n    model.compile(optimizer=opt, loss=loss)\n    opt.register_layers(model=model)\n    model.compile(optimizer=opt, loss=loss)\n    opt.get_updates(model.total_loss, model.trainable_weights)\n\n  def testTrainWithoutCreatingOptimizerFails(self):\n    with self.assertRaisesRegex(ValueError, '.*provide a model with a loss.*'):\n      opt = optimizers.Kfac(learning_rate=0.01, damping=0.001)\n      model = _mnist_model()\n      model.compile(optimizer=opt, loss='categorical_crossentropy')\n      grads_vars = opt.get_gradients(model.total_loss, model.trainable_weights)\n      opt.apply_gradients(grads_vars)\n\n  def testEmptyCreateKfacOptimizerFails(self):\n    with self.assertRaisesRegex(ValueError, '.*provide a model with a loss.*'):\n      opt = optimizers.Kfac(learning_rate=0.01, damping=0.001)\n      opt._create_optimizer()\n\n  def 
testSeed(self):\n    opt = optimizers.Kfac(learning_rate=0.01, damping=0.01,\n                              model=_simple_mlp(), loss='mse', seed=4321)\n    lc = opt.optimizer.layers\n    self.assertEqual(lc._loss_dict['squared_error_loss'][0]._default_seed, 4321)\n\n  def testNewOptSameVarScope(self):\n    model = _simple_mlp()\n    opt = optimizers.Kfac(\n        learning_rate=0.01, damping=0.01, model=model, loss='mse')\n    opt._create_optimizer()\n    opt2 = optimizers.Kfac(\n        learning_rate=0.02, damping=0.03, model=model, loss='mse')\n    opt2._create_optimizer()\n\n  def testGetSetWeights(self):\n    def model_maker():\n      return tf.keras.Sequential([layers.Dense(2, input_shape=(3,))])\n\n    x = np.random.random((1, 3))\n    y = np.random.random((1, 2))\n    loss = 'mse'\n    model = model_maker()\n    opt = optimizers.Kfac(learning_rate=0.01, damping=0.1,\n                              model=model, loss=loss, seed=1234)\n    model.compile(optimizer=opt, loss=loss)\n    model.train_on_batch(x, y)\n    opt_weights = opt.get_weights()\n\n    self.assertEqual(1, opt_weights[0])  # iterations\n    self.assertEqual(1, opt_weights[6])  # counter\n    self.assertEqual(0, opt_weights[7])  # burn in counter\n\n    config = opt.get_config()\n    config['name'] = 'diff_name'\n    opt2 = optimizers.Kfac.from_config(config)\n    model2 = model_maker()\n    model2.compile(optimizer=opt2, loss=loss)\n    opt2.register_layers(model=model2)\n    # Set weights should only work after a call to get_updates/apply_gradients.\n    x = np.random.random((1, 3))\n    y = np.random.random((1, 2))\n    model2.train_on_batch(x, y)\n    opt2.set_weights(opt_weights)\n\n    for w1, w2 in zip(opt_weights, opt2.get_weights()):\n      self.assertAllClose(w1, w2)\n\n    model2.set_weights(model.get_weights())\n    x = np.random.random((1, 3))\n    y = np.random.random((1, 2))\n    model.train_on_batch(x, y)\n    model2.train_on_batch(x, y)\n\n    for w1, w2 in 
zip(opt.get_weights(), opt2.get_weights()):\n      self.assertAllClose(w1, w2)\n\n  @parameterized.named_parameters(('_HasShift', True), ('_NoShift', False))\n  def testTrainModelWithNormalization(self, has_shift):\n    model = tf.keras.Sequential([\n        layers.Conv2D(13, 5, input_shape=(28, 28, 1)),\n        layers.BatchNormalization(center=has_shift, fused=False),\n        layers.Conv2D(23, 3),\n        layers.LayerNormalization(center=has_shift),\n        layers.GlobalMaxPool2D(),\n        layers.Dense(10, activation='softmax')\n    ])\n    (x_train, y_train), _ = _get_synthetic_mnist_dataset()\n    approx = {layers.LayerNormalization: 'full'}\n    loss = 'categorical_crossentropy'\n    opt = optimizers.Kfac(learning_rate=0.01, damping=0.01,\n                              model=model, loss=loss, fisher_approx=approx)\n    model.compile(opt, loss)\n    return model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=0)\n\n  @parameterized.named_parameters(('_HasShift', True), ('_NoShift', False))\n  def testTrainModelWithFusedBN(self, has_shift):\n    model = tf.keras.Sequential([\n        layers.Conv2D(13, 5, input_shape=(28, 28, 1)),\n        layers.BatchNormalization(center=has_shift, fused=True),\n        layers.GlobalMaxPool2D(),\n        layers.Dense(10, activation='softmax')\n    ])\n    (x_train, y_train), _ = _get_synthetic_mnist_dataset()\n    loss = 'categorical_crossentropy'\n    opt = optimizers.Kfac(\n        learning_rate=0.01, damping=0.01, model=model, loss=loss)\n    model.compile(opt, loss)\n    return model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=0)\n\n  @parameterized.named_parameters(('_HasShift', True), ('_NoShift', False))\n  def testTrainModelWithFusedBNAndLearningPhase(self, has_shift):\n    tf.keras.backend.set_learning_phase(1)\n    model = tf.keras.Sequential([\n        layers.Conv2D(13, 5, input_shape=(28, 28, 1)),\n        layers.BatchNormalization(center=has_shift, fused=True),\n        
layers.GlobalMaxPool2D(),\n        layers.Dense(10, activation='softmax')\n    ])\n    (x_train, y_train), _ = _get_synthetic_mnist_dataset()\n    loss = 'categorical_crossentropy'\n    opt = optimizers.Kfac(\n        learning_rate=0.01, damping=0.01, model=model, loss=loss)\n    model.compile(opt, loss)\n    return model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=0)\n\n  @parameterized.named_parameters(('_WithShape', {'input_shape': (28, 28, 1)}),\n                                  ('_WithoutShape', {}))\n  def testCustomTrainingLoopSequential(self, input_conv_kwargs):\n    # Without the input_shape the only inbound node is the correct one, with the\n    # input_shape there are two, and we want the second one.\n    model = tf.keras.Sequential([\n        layers.Conv2D(13, 5, **input_conv_kwargs),\n        layers.BatchNormalization(fused=False),\n        layers.Conv2D(23, 3),\n        layers.LayerNormalization(),\n        layers.GlobalMaxPool2D(),\n        layers.Dense(10, activation='softmax', name='output_test')\n    ])\n    x, y = _get_synthetic_mnist_train_tensors(batch_size=10)\n    model_input = tf.keras.Input(tensor=x)\n    output = model(model_input)\n    loss = tf.keras.losses.binary_crossentropy(output, y)\n    optimizer = optimizers.Kfac(learning_rate=0.01, damping=0.01,\n                                    model=model, loss='binary_crossentropy')\n    train_op = optimizer.minimize(loss, var_list=model.trainable_weights)\n    with tf.Session() as sess:\n      sess.run(tf.global_variables_initializer())\n      for _ in range(3):\n        sess.run([train_op])\n\n  def testCustomTrainingLoopFunctionalInpTensor(self):\n    # This case should work trivially--the only inbound node is the correct one.\n    x, y = _get_synthetic_mnist_train_tensors(batch_size=10)\n\n    # Build Model\n    inp = tf.keras.Input(tensor=x)\n    x = layers.Conv2D(13, 5)(inp)\n    x = layers.BatchNormalization(fused=False)(x)\n    x = layers.Conv2D(23, 3)(x)\n    x = 
layers.LayerNormalization()(x)\n    x = layers.GlobalMaxPool2D()(x)\n    out = layers.Dense(10, activation='softmax', name='output_test')(x)\n    model = tf.keras.Model(inputs=inp, outputs=out)\n\n    loss = tf.keras.losses.binary_crossentropy(model.output, y)\n    optimizer = optimizers.Kfac(learning_rate=0.01, damping=0.01,\n                                    model=model, loss='binary_crossentropy')\n    train_op = optimizer.minimize(loss, var_list=model.trainable_weights)\n    with tf.Session() as sess:\n      sess.run(tf.global_variables_initializer())\n      for _ in range(3):\n        sess.run([train_op])\n\n  def testCustomTrainingLoopFunctionalInpShape(self):\n    # We need to ensure correct inbound node is used for layer collection.\n    x, y = _get_synthetic_mnist_train_tensors(batch_size=10)\n    model_input = tf.keras.Input(tensor=x)\n\n    # Build Model\n    inp = tf.keras.Input(shape=(28, 28, 1))\n    x = layers.Conv2D(13, 5)(inp)\n    x = layers.BatchNormalization(fused=True)(x)\n    x = layers.Activation('relu')(x)\n    x = layers.Conv2D(23, 3)(x)\n    x = layers.LayerNormalization()(x)\n    x = layers.GlobalMaxPool2D()(x)\n    out = layers.Dense(10, activation='softmax', name='output_test')(x)\n    model = tf.keras.Model(inputs=inp, outputs=out)\n\n    output = model(model_input)\n    loss = tf.keras.losses.binary_crossentropy(output, y)\n    optimizer = optimizers.Kfac(damping=0.01, learning_rate=0.01,\n                                    model=model, loss='binary_crossentropy')\n    train_op = optimizer.minimize(loss, var_list=model.trainable_weights)\n    with tf.Session() as sess:\n      sess.run(tf.global_variables_initializer())\n      for _ in range(3):\n        sess.run([train_op])\n\n  def testCustomTrainingLoopMakeOptimizerBeforeModelCall(self):\n    # We defer the creation of the layer_collection to the minimize call for\n    # this situation, because if we make the layer_collection immediately it\n    # will capture the wrong inbound 
node.\n    model = tf.keras.Sequential([\n        layers.Conv2D(13, 5),\n        layers.BatchNormalization(fused=False),\n        layers.Conv2D(23, 3),\n        layers.LayerNormalization(),\n        layers.GlobalMaxPool2D(),\n        layers.Dense(10, activation='softmax', name='output_test')\n    ])\n    optimizer = optimizers.Kfac(learning_rate=0.01, damping=0.01,\n                                    model=model, loss='binary_crossentropy')\n    x, y = _get_synthetic_mnist_train_tensors(batch_size=10)\n    model_input = tf.keras.Input(tensor=x)\n    output = model(model_input)\n    loss = tf.keras.losses.binary_crossentropy(output, y)\n    train_op = optimizer.minimize(loss, var_list=model.trainable_weights)\n    with self.cached_session() as sess:\n      sess.run(tf.global_variables_initializer())\n      for _ in range(3):\n        sess.run([train_op])\n\n  def testCustomTrainingUnwrappedTensorFails(self):\n    # This test does not test our implementation, but is here so if Keras ever\n    # adds functionality to support raw tensors as Nodes, this test will fail\n    # and we can remove the restriction from our documentation.\n    model = _simple_mlp()\n    dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(10)\n    x, y = dataset.make_one_shot_iterator().get_next()\n    pred = model(x)\n    loss = tf.keras.losses.binary_crossentropy(pred, y)\n    optimizer = optimizers.Kfac(learning_rate=0.01, damping=0.01,\n                                    model=model, loss='binary_crossentropy')\n    train_op = optimizer.minimize(loss, var_list=model.trainable_weights)\n    with self.cached_session() as sess:\n      sess.run(tf.global_variables_initializer())\n      with self.assertRaisesRegex(tf.errors.InvalidArgumentError,\n                                  '.*You must feed a value for placeholder.*'):\n        sess.run([train_op])\n\n  def testTrainingNestedModel(self):\n    inputs = tf.keras.Input(shape=(1,))\n    y1 = _simple_mlp()(inputs)\n    y2 = 
_simple_mlp()(inputs)\n    y3 = _simple_mlp()(inputs)\n    outputs = layers.average([y1, y2, y3])\n    ensemble_model = tf.keras.Model(inputs=inputs, outputs=outputs)\n\n    optimizer = optimizers.Kfac(learning_rate=0.01,\n                                    damping=0.01,\n                                    model=ensemble_model,\n                                    loss='binary_crossentropy')\n    ensemble_model.compile(optimizer, 'binary_crossentropy')\n\n    dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(10)\n    x, y = dataset.make_one_shot_iterator().get_next()\n    ensemble_model.train_on_batch(x, y)\n\n  def testCustomTrainLoopNestedModel(self):\n    inputs = tf.keras.Input(shape=(1,))\n    y1 = _simple_mlp()(inputs)\n    y2 = _simple_mlp()(inputs)\n    y3 = _simple_mlp()(inputs)\n    outputs = layers.average([y1, y2, y3])\n    ensemble_model = tf.keras.Model(inputs=inputs, outputs=outputs)\n\n    dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(10)\n    x, y = dataset.make_one_shot_iterator().get_next()\n    x = layers.Input(tensor=x)\n\n    optimizer = optimizers.Kfac(learning_rate=0.01,\n                                    damping=0.01,\n                                    model=ensemble_model,\n                                    loss='binary_crossentropy')\n\n    pred = ensemble_model(x)\n    loss = tf.keras.losses.binary_crossentropy(pred, y)\n    train_op = optimizer.minimize(\n        loss, var_list=ensemble_model.trainable_weights)\n    with self.cached_session() as sess:\n      sess.run(tf.global_variables_initializer())\n      sess.run([train_op])\n\n  @parameterized.named_parameters(\n      ('_NoKwargs', {'norm_constraint'}, {}),\n      ('_MomentumNormKwargs',\n       set(),\n       {'momentum': 1, 'norm_constraint': 2}),\n      ('_QModel',\n       {'momentum', 'learning_rate', 'norm_constraint'},\n       {'momentum': None, 'momentum_type': 'qmodel', 'learning_rate': None}),\n      ('_AdaptiveDamping',\n    
   {'damping', 'norm_constraint'},\n       {'adapt_damping': True, 'damping_adaptation_interval': 20}))\n  def testMutableHypers(self, not_mutable, kwargs_update):\n    kwargs = {'learning_rate': 0.01, 'damping': 0.001}\n    kwargs.update(kwargs_update)\n    opt = optimizers.Kfac(model=_simple_mlp(), loss='mse', **kwargs)\n    mutable = optimizers._MUTABLE_HYPER_PARAMS - not_mutable\n    self.assertEqual(set(opt.mutable_hyperparameters), mutable)\n\n  def testPositionalArgsFail(self):\n    with self.assertRaisesRegex(ValueError,\n                                'Do not pass positional arguments.*'):\n      optimizers.Kfac(0.1, 0.1, model=_simple_mlp(), loss='mse')\n\n  def testSettingName(self):\n    model = _simple_mlp()\n    optimizer = optimizers.Kfac(damping=0.01, learning_rate=0.01,\n                                    model=model, loss='mse')\n    optimizer.name = 'new_name'\n    self.assertEqual(optimizer._name, 'new_name')\n    self.assertEqual(optimizer.get_config()['name'], 'new_name')\n    self.assertEqual(optimizer._kfac_kwargs['name'], 'new_name')\n    model.compile(optimizer, 'mse')\n    model._make_train_function()\n    with self.assertRaisesRegex(ValueError,\n                                '.*after the variables are created.*'):\n      optimizer.name = 'another_name'\n\n  @parameterized.named_parameters(\n      ('_AdaptDamping', {'adapt_damping': True, 'learning_rate': 0.1}),\n      ('_Adaptive', {'adaptive': True, 'qmodel_update_rescale': 0.01}))\n  def testAdaptiveModelFit(self, adaptive_kwargs):\n    rands = lambda: np.random.random((100, 1)).astype(np.float32)\n    dataset = tf.data.Dataset.from_tensor_slices((rands(), rands()))\n    dataset = dataset.repeat().batch(10, drop_remainder=True)\n    train_batch = dataset.make_one_shot_iterator().get_next()\n\n    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])\n    loss = 'mse'\n    train_batch = dataset.make_one_shot_iterator().get_next()\n    optimizer = 
optimizers.Kfac(damping=10.,\n                                train_batch=train_batch,\n                                model=model,\n                                loss=loss,\n                                **adaptive_kwargs)\n    model.compile(optimizer, loss)\n    model.fit(train_batch, steps_per_epoch=10, epochs=1)\n\n  @parameterized.named_parameters(('_Fused', True), ('_NotFused', False))\n  def testAdaptiveModelFitBatchnorm(self, is_fused):\n    train_batch = _get_synthetic_mnist_train_tensors(drop_remainder=True)\n    model =  tf.keras.Sequential([\n      layers.Conv2D(13, 5, input_shape=(28,28,1)),\n      layers.BatchNormalization(fused=is_fused),\n      layers.Conv2D(23, 3),\n      layers.LayerNormalization(),\n      layers.GlobalMaxPool2D(),\n      layers.Dense(10, activation='softmax', name='output_test')\n    ])\n    loss = 'categorical_crossentropy'\n    optimizer = optimizers.Kfac(damping=10.,\n                                adaptive=True,\n                                train_batch=train_batch,\n                                model=model,\n                                loss=loss)\n    model.compile(optimizer, loss)\n    model.train_on_batch(x=train_batch[0], y=train_batch[1])\n\n  def testInferredBatchSize(self):\n    dataset = tf.data.Dataset.from_tensors(([1.], [1.]))\n    dataset = dataset.repeat().batch(11, drop_remainder=True)\n    train_batch = dataset.make_one_shot_iterator().get_next()\n\n    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])\n    loss = 'mse'\n    train_batch = dataset.make_one_shot_iterator().get_next()\n    optimizer = optimizers.Kfac(damping=10.,\n                                train_batch=train_batch,\n                                model=model,\n                                adaptive=True,\n                                loss=loss,\n                                qmodel_update_rescale=0.01)\n    model.compile(optimizer, loss)\n    model.train_on_batch(train_batch[0], 
train_batch[1])\n    self.assertEqual(\n        tf.keras.backend.get_value(optimizer.optimizer._batch_size), 11)\n\n  @parameterized.named_parameters(('_Adaptive', {'adaptive': True}),\n                                  ('_AdaptDamping', {'adapt_damping': True}))\n  def testInferredBatchSizeFail(self, kfac_kwargs):\n    dataset = tf.data.Dataset.from_tensors(([1.], [1.]))\n    dataset = dataset.repeat().batch(11, drop_remainder=False)\n    train_batch = dataset.make_one_shot_iterator().get_next()\n    with self.assertRaisesRegex(ValueError, 'Could not infer batch_size.*'):\n      optimizer = optimizers.Kfac(damping=10.,\n                                  train_batch=train_batch,\n                                  **kfac_kwargs)\n\n  def testOverrideAdaptiveDefaults(self):\n    dataset = tf.data.Dataset.from_tensors(([1.], [1.]))\n    dataset = dataset.repeat().batch(11, drop_remainder=False)\n    train_batch = dataset.make_one_shot_iterator().get_next()\n\n    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])\n    loss = 'mse'\n    train_batch = dataset.make_one_shot_iterator().get_next()\n    optimizer = optimizers.Kfac(damping=10.,\n                                adaptive=True,\n                                train_batch=train_batch,\n                                model=model,\n                                batch_size=11,\n                                invert_every=1,\n                                damping_adaptation_interval=2,\n                                loss=loss,\n                                qmodel_update_rescale=0.01)\n    model.compile(optimizer, loss)\n    model.train_on_batch(train_batch[0], train_batch[1])\n    self.assertEqual(optimizer.optimizer._invert_every, 1)\n    self.assertEqual(optimizer.optimizer._damping_adaptation_interval, 2)\n\n  @parameterized.named_parameters(('_Adaptive', {'adaptive': True}),\n                                  ('_Qmodel', {'momentum_type': 'qmodel'}))\n  def 
testAdaptiveWithLR(self, kfac_kwargs):\n    dataset = tf.data.Dataset.from_tensors(([1.], [1.]))\n    dataset = dataset.repeat().batch(11, drop_remainder=True)\n    train_batch = dataset.make_one_shot_iterator().get_next()\n    with self.assertRaisesRegex(ValueError, 'learning_rate must be None.*'):\n      optimizer = optimizers.Kfac(damping=10.,\n                                  train_batch=train_batch,\n                                  learning_rate=0.1,\n                                  **kfac_kwargs)\n\n  def testCustomLossFn(self):\n    rands = lambda: np.random.random((100, 1)).astype(np.float32)\n    dataset = tf.data.Dataset.from_tensor_slices((rands(), rands()))\n    dataset = dataset.repeat().batch(10, drop_remainder=True)\n    train_batch = dataset.make_one_shot_iterator().get_next()\n    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])\n\n    def loss_fn(inputs):\n      mse = tf.keras.losses.mean_squared_error(model(inputs[0]), inputs[1])\n      return tf.reduce_mean(mse)\n\n    loss = 'mse'\n    train_batch = dataset.make_one_shot_iterator().get_next()\n    optimizer = optimizers.Kfac(damping=10.,\n                                train_batch=train_batch,\n                                adaptive=True,\n                                model=model,\n                                loss=loss,\n                                loss_fn=loss_fn,\n                                qmodel_update_rescale=0.01)\n    model.compile(optimizer, loss)\n    model.fit(train_batch, steps_per_epoch=10, epochs=1)\n    self.assertEqual(loss_fn, optimizer.optimizer._loss_fn)\n\n  def testRegisterTrainBatch(self):\n    model =  tf.keras.Sequential([\n      layers.Conv2D(13, 5, input_shape=(28,28,1)),\n      layers.BatchNormalization(),\n      layers.Conv2D(23, 3),\n      layers.LayerNormalization(),\n      layers.GlobalMaxPool2D(),\n      layers.Dense(10, activation='softmax', name='output_test')\n    ])\n    loss = 'categorical_crossentropy'\n    
optimizer = optimizers.Kfac(damping=10.,\n                                adaptive=True,\n                                model=model,\n                                loss=loss)\n    model.compile(optimizer, loss)\n    train_batch = _get_synthetic_mnist_train_tensors(drop_remainder=True)\n    optimizer.register_train_batch(train_batch)\n    model.train_on_batch(x=train_batch[0], y=train_batch[1])\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/keras_saving_utils_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#,============================================================================\n\"\"\"Tests for keras/saving_utils.py.\n\nThese tests were forked from the hdf5_format_test.py tests in Keras.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport tempfile\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\nfrom tensorflow.python.framework import test_util\nfrom kfac.python.keras import optimizers\nfrom kfac.python.keras import saving_utils\n\ntry:\n  import h5py  # pylint:disable=g-import-not-at-top\nexcept ImportError:\n  h5py = None\n\nkeras = tf.keras\n_KFAC_KWARGS = {\n    'learning_rate': 0.0001,\n    'damping': 0.01,\n    'momentum': 0.85,\n    'fisher_approx': {\n        keras.layers.Dense: 'kron_in_diag',\n    },\n    'loss': 'mse',\n    # This seed is necessary to keep the optimizer updates deterministic, since\n    # we're approximating the true Fisher by sampling the targets. 
Since for\n    # many tests we only do one training step, the approximations can vary\n    # significantly without a set seed.\n    'seed': 1234,\n}\n\n\nclass SavingUtilsTest(tf.test.TestCase):\n\n  @test_util.run_v1_only('b/120994067')\n  def test_sequential_model_saving(self):\n    if h5py is None:\n      self.skipTest('h5py required to run this test')\n\n    with self.cached_session():\n      model = keras.models.Sequential()\n      model.add(keras.layers.Dense(2, input_shape=(2,)))\n      model.add(keras.layers.RepeatVector(3))\n      model.add(keras.layers.Flatten())\n      model.add(keras.layers.Dense(3))\n      model.compile(\n          loss=keras.losses.MSE,\n          optimizer=optimizers.Kfac(model=model, **_KFAC_KWARGS),\n          metrics=[\n              keras.metrics.categorical_accuracy,\n              keras.metrics.CategoricalAccuracy()\n          ])\n\n      x = np.random.random((1, 2))\n      y = np.random.random((1, 3))\n\n      # TODO(b/136561651): Since we use TFP distributions to sample from the\n      # output distribution, optimizer's won't match exactly unless they are run\n      # for the same number of steps. Even with a random seed, the internal\n      # state of TFP changes with each call. We must switch to a stateless\n      # sampler. 
Uncomment the train line below once this is implemented.\n      # model.train_on_batch(x, y)\n\n      out = model.predict(x)\n      fd, fname = tempfile.mkstemp('.h5')\n      keras.models.save_model(model, fname)\n\n      new_model = saving_utils.load_model(fname, optimizer_name='new')\n      os.close(fd)\n      os.remove(fname)\n\n      out2 = new_model.predict(x)\n      self.assertAllClose(out, out2, atol=1e-05)\n\n      # test that new updates are the same with both models\n      x = np.random.random((1, 2))\n      y = np.random.random((1, 3))\n      model.train_on_batch(x, y)\n      new_model.train_on_batch(x, y)\n\n      x = np.random.random((1, 2))\n      y = np.random.random((1, 3))\n      eval_out = model.evaluate(x, y)\n      eval_out2 = new_model.evaluate(x, y)\n      self.assertArrayNear(eval_out, eval_out2, 1e-03)\n\n      out = model.predict(x)\n      out2 = new_model.predict(x)\n\n      self.assertAllClose(out, out2, atol=1e-05)\n\n  @test_util.run_deprecated_v1\n  def test_functional_model_saving(self):\n    if h5py is None:\n      self.skipTest('h5py required to run this test')\n\n    with self.cached_session():\n      inputs = keras.layers.Input(shape=(3,))\n      x = keras.layers.Dense(2)(inputs)\n      output = keras.layers.Dense(3)(x)\n\n      model = keras.models.Model(inputs, output)\n      model.compile(\n          loss=keras.losses.MSE,\n          optimizer=optimizers.Kfac(model=model, **_KFAC_KWARGS),\n          metrics=[\n              keras.metrics.categorical_accuracy,\n              keras.metrics.CategoricalAccuracy()\n          ],\n          weighted_metrics=[\n              keras.metrics.categorical_accuracy,\n              keras.metrics.CategoricalAccuracy()\n          ])\n      x = np.random.random((1, 3))\n      y = np.random.random((1, 3))\n      model.train_on_batch(x, y)\n\n      out = model.predict(x)\n      fd, fname = tempfile.mkstemp('.h5')\n      keras.models.save_model(model, fname)\n\n      model = 
saving_utils.load_model(fname, optimizer_name='new')\n      os.close(fd)\n      os.remove(fname)\n\n      out2 = model.predict(x)\n      self.assertAllClose(out, out2, atol=1e-05)\n\n  def test_saving_model_with_long_layer_names(self):\n    if h5py is None:\n      self.skipTest('h5py required to run this test')\n\n    with self.cached_session():\n      # This layer name will make the `layers_name` HDF5 attribute blow\n      # out of proportion. Note that it fits into the internal HDF5\n      # attribute memory limit on its own but because h5py converts\n      # the list of layer names into numpy array, which uses the same\n      # amount of memory for every item, it increases the memory\n      # requirements substantially.\n      x = keras.Input(shape=(2,), name='input_' + ('x' * (2**15)))\n      f = x\n      for i in range(4):\n        f = keras.layers.Dense(2, name='dense_%d' % (i,))(f)\n      model = keras.Model(inputs=[x], outputs=[f])\n      model.compile(optimizers.Kfac(model=model, **_KFAC_KWARGS),\n                    loss=keras.losses.MeanSquaredError(),\n                    metrics=['acc'])\n\n      x = np.random.random((1, 2))\n      y = np.random.random((1, 2))\n      model.train_on_batch(x, y)\n      out = model.predict(x)\n\n      fd, fname = tempfile.mkstemp('.h5')\n      keras.models.save_model(model, fname)\n      model = saving_utils.load_model(fname, optimizer_name='new')\n\n      # Check that the HDF5 files contains chunked array\n      # of layer names.\n      with h5py.File(fname, 'r') as h5file:\n        num_names_arrays = len([attr for attr in h5file['model_weights'].attrs\n                                if attr.startswith('layer_names')])\n      # The chunking of layer names array should have happened.\n      self.assertGreater(num_names_arrays, 0)\n      out2 = model.predict(x)\n      self.assertAllClose(out, out2, atol=1e-05)\n\n      # Cleanup\n      os.close(fd)\n      os.remove(fname)\n\n  def 
test_saving_model_with_long_weights_names(self):\n    self.skipTest('KFAC does not support nested models yet.')\n    if h5py is None:\n      self.skipTest('h5py required to run this test')\n\n    with self.cached_session():\n      x = keras.Input(shape=(2,), name='nested_model_input')\n      f = x\n      for i in range(4):\n        f = keras.layers.Dense(2, name='nested_model_dense_%d' % (i,))(f)\n      # This layer name will make the `weights_name`\n      # HDF5 attribute blow out of proportion.\n      f = keras.layers.Dense(2, name='nested_model_output' + ('x' * (2**14)))(f)\n      nested_model = keras.Model(inputs=[x], outputs=[f], name='nested_model')\n\n      x = keras.Input(shape=(2,), name='outer_model_input')\n      f = nested_model(x)\n      f = keras.layers.Dense(2, name='outer_model_output')(f)\n\n      model = keras.Model(inputs=[x], outputs=[f])\n      model.compile(loss='mse',\n                    optimizer=optimizers.Kfac(model=model, **_KFAC_KWARGS),\n                    metrics=['acc'])\n\n      x = np.random.random((1, 2))\n      y = np.random.random((1, 2))\n      model.train_on_batch(x, y)\n      out = model.predict(x)\n\n      fd, fname = tempfile.mkstemp('.h5')\n      keras.models.save_model(model, fname)\n      model = saving_utils.load_model(fname, optimizer_name='new')\n\n      # Check that the HDF5 files contains chunked array\n      # of weight names.\n      with h5py.File(fname, 'r') as h5file:\n        num_weight_arrays = len(\n            [attr for attr in h5file['model_weights']['nested_model'].attrs\n             if attr.startswith('weight_names')])\n      # The chunking of layer names array should have happened.\n      self.assertGreater(num_weight_arrays, 0)\n      out2 = model.predict(x)\n      self.assertAllClose(out, out2, atol=1e-05)\n\n      # Cleanup\n      os.close(fd)\n      os.remove(fname)\n\n  @test_util.run_deprecated_v1\n  def test_model_saving_to_pre_created_h5py_file(self):\n    if h5py is None:\n      
self.skipTest('h5py required to run this test')\n\n    with self.cached_session():\n      inputs = keras.Input(shape=(3,))\n      x = keras.layers.Dense(2)(inputs)\n      outputs = keras.layers.Dense(3)(x)\n\n      model = keras.Model(inputs, outputs)\n      model.compile(\n          loss=keras.losses.MSE,\n          optimizer=optimizers.Kfac(model=model, **_KFAC_KWARGS),\n          metrics=[\n              keras.metrics.categorical_accuracy,\n              keras.metrics.CategoricalAccuracy()\n          ])\n      x = np.random.random((1, 3))\n      y = np.random.random((1, 3))\n      model.train_on_batch(x, y)\n\n      out = model.predict(x)\n      fd, fname = tempfile.mkstemp('.h5')\n      with h5py.File(fname, mode='r+') as h5file:\n        keras.models.save_model(model, h5file)\n        loaded_model = saving_utils.load_model(h5file, optimizer_name='new')\n        out2 = loaded_model.predict(x)\n      self.assertAllClose(out, out2, atol=1e-05)\n\n      # Test non-default options in h5\n      with h5py.File(\n          '-', driver='core', mode='w', backing_store=False) as h5file:\n        keras.models.save_model(model, h5file)\n        loaded_model = saving_utils.load_model(h5file, optimizer_name='new2')\n        out2 = loaded_model.predict(x)\n      self.assertAllClose(out, out2, atol=1e-05)\n\n      # Cleanup\n      os.close(fd)\n      os.remove(fname)\n\n  def test_saving_constant_initializer_with_numpy(self):\n    if h5py is None:\n      self.skipTest('h5py required to run this test')\n\n    with self.cached_session():\n      model = keras.models.Sequential()\n      model.add(\n          keras.layers.Dense(\n              2,\n              input_shape=(3,),\n              kernel_initializer=keras.initializers.Constant(np.ones((3, 2)))))\n      model.add(keras.layers.Dense(3))\n      model.compile(loss='mse',\n                    optimizer=optimizers.Kfac(model=model, **_KFAC_KWARGS),\n                    metrics=['acc'])\n      fd, fname = 
tempfile.mkstemp('.h5')\n      keras.models.save_model(model, fname)\n      model = saving_utils.load_model(fname, optimizer_name='new')\n      os.close(fd)\n      os.remove(fname)\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/keras_utils_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for keras/utils.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.keras import utils\nfrom kfac.python.ops import fisher_blocks\nfrom kfac.python.ops import loss_functions\n\nlayers = tf.keras.layers\nlosses = tf.keras.losses\n_SEED = 1234\n\n\ndef _mlp():\n  return tf.keras.Sequential([\n      layers.Embedding(100, 13, input_length=1),\n      layers.Flatten(),\n      layers.Dense(32, activation='tanh'),\n      layers.Dense(32, activation='tanh'),\n      layers.Dense(1)\n  ])\n\n\ndef _cnn():\n  return tf.keras.Sequential([\n      layers.Conv2D(7, 5, input_shape=(28, 28, 3)),\n      layers.Activation('relu'),\n      layers.Conv2D(13, (3, 3), activation='relu'),\n      layers.GlobalMaxPool2D(),\n      layers.Activation('softmax')\n  ])\n\n\ndef _two_loss_model(num_branch1_outputs=1, num_branch2_outputs=9):\n  inp = layers.Input(shape=(28, 28, 1))\n\n  branch1 = layers.Lambda(lambda x: tf.squeeze(x, -1))(inp)\n  branch1 = layers.Conv1D(13, 7, activation='relu')(branch1)\n  branch1 = layers.GlobalMaxPool1D()(branch1)\n  branch1 = 
layers.Dense(num_branch1_outputs, name='out1')(branch1)\n\n  branch2 = layers.Conv2D(16, 3, activation='relu')(inp)\n  branch2 = layers.MaxPooling2D(pool_size=(4, 4))(branch2)\n  branch2 = layers.Flatten()(branch2)\n  branch2 = layers.Dense(num_branch2_outputs, name='out2')(branch2)\n\n  return inp, (branch1, branch2)\n\n\nclass GetLayerCollectionTest(parameterized.TestCase, tf.test.TestCase):\n\n  def setUp(self):\n    super(GetLayerCollectionTest, self).setUp()\n    tf.reset_default_graph()\n    tf.random.set_random_seed(_SEED)\n\n  @parameterized.named_parameters(\n      ('_Categorical', 'categorical_crossentropy',\n       loss_functions.CategoricalLogitsNegativeLogProbLoss),\n      ('_Binary', 'binary_crossentropy',\n       loss_functions.MultiBernoulliNegativeLogProbLoss),\n      ('_Sparse', losses.sparse_categorical_crossentropy,\n       loss_functions.CategoricalLogitsNegativeLogProbLoss))\n  def testValidLogitLossFunctionsCNN(self, loss, kfac_loss):\n    \"\"\"Ensures correct tensorflow_kfac loss function and variable for a CNN.\n\n    Args:\n      loss: A losses function (in serialized form or actual reference)\n      kfac_loss: tensorflow_kfac.python.ops loss function.\n    \"\"\"\n    with tf.Graph().as_default():\n      model = _cnn()\n      lc = utils.get_layer_collection(model, loss)\n      self.assertIsInstance(lc.losses[0], kfac_loss)\n      self.assertEqual(lc.losses[0].params,\n                       utils.get_parent(model.layers[-1].output))\n\n  @parameterized.named_parameters(\n      ('_Categorical', 'categorical_crossentropy',\n       loss_functions.CategoricalLogitsNegativeLogProbLoss),\n      ('_Binary', 'binary_crossentropy',\n       loss_functions.MultiBernoulliNegativeLogProbLoss),\n      ('_Sparse', losses.sparse_categorical_crossentropy,\n       loss_functions.CategoricalLogitsNegativeLogProbLoss))\n  def testValidLogitLossFunctionsMLP(self, loss, kfac_loss):\n    \"\"\"Ensures correct tensorflow_kfac loss function and variable for a 
MLP.\n\n    Args:\n      loss: A losses function (in serialized form or actual reference)\n      kfac_loss: tensorflow_kfac.python.ops loss function.\n    \"\"\"\n    with tf.Graph().as_default():\n      model = _mlp()\n      lc = utils.get_layer_collection(model, loss)\n      self.assertIsInstance(lc.losses[0], kfac_loss)\n      self.assertEqual(lc.losses[0].params, model.layers[-1].output)\n\n  @parameterized.named_parameters(('_LongCNN', 'mean_squared_error', _cnn),\n                                  ('ShortCNN', 'mse', _cnn),\n                                  ('_LongMLP', losses.mean_squared_error, _mlp),\n                                  ('ShortMLP', 'mse', _mlp),\n                                  ('_Class', losses.MeanSquaredError(), _mlp))\n  def testValidMSE(self, loss, model_builder):\n    \"\"\"Ensures variations of MSE and output variables work.\n\n    Args:\n      loss: A tf.keras.losses function (in serialized form or actual reference)\n      model_builder: Function that returns a Keras model.\n    \"\"\"\n    model = model_builder()\n    lc = utils.get_layer_collection(model, loss)\n    self.assertIsInstance(lc.losses[0],\n                          loss_functions.NormalMeanNegativeLogProbLoss)\n    self.assertEqual(lc.losses[0].params, model.layers[-1].output)\n\n  @parameterized.named_parameters(('_NotRealLoss', 'blah blah blah'),\n                                  ('_RealButInvalid', 'cosine'),\n                                  ('_SimilarName', 'msle'))\n  def testInvalidLossFunctions(self, loss):\n    with self.assertRaisesRegex(ValueError, '.*loss function:.*'):\n      model = _mlp()\n      utils.get_layer_collection(model, loss)\n\n  @parameterized.named_parameters(('_CNN', _cnn), ('_MLP', _mlp))\n  def testLayerRegistration(self, model_builder):\n    model = model_builder()\n    model.layers[0].trainable = False\n\n    lc = utils.get_layer_collection(model, 'mse')\n    registered = set(lc.registered_variables)\n\n    variables = set()\n    
for layer in model.layers[1:]:\n      if layer.trainable and layer.count_params():\n        variables |= set(layer.weights)\n\n    self.assertEqual(registered, variables)\n\n  @parameterized.named_parameters(\n      ('_DictLoss',\n       {'out1': 'binary_crossentropy', 'out2': 'categorical_crossentropy'},\n       {'out1': 0.1, 'out2': 0.9}),\n      ('_ListLoss',\n       ['binary_crossentropy', 'categorical_crossentropy'],\n       [0.1, 0.9]))\n  def testMultipleLoss(self, loss, loss_weights):\n    inputs, (out1, out2) = _two_loss_model()\n    model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])\n    lc = utils.get_layer_collection(model, loss, loss_weights=loss_weights)\n\n    self.assertLen(lc.loss_coeffs.keys(), 2)\n    self.assertLen(lc.loss_colocation_ops.keys(), 2)\n\n    l1 = lc._loss_dict['sigmoid_cross_entropy_loss']\n    l2 = lc._loss_dict['sparse_softmax_cross_entropy_loss']\n\n    self.assertLen(l1, 1)\n    self.assertLen(l2, 1)\n\n    l1, l2 = l1[0], l2[0]\n\n    self.assertIsInstance(l1,\n                          loss_functions.MultiBernoulliNegativeLogProbLoss)\n    self.assertIsInstance(l2,\n                          loss_functions.CategoricalLogitsNegativeLogProbLoss)\n    self.assertEqual(lc.loss_coeffs[l1], 0.1)\n    self.assertEqual(lc.loss_coeffs[l2], 0.9)\n    self.assertEqual(lc.loss_colocation_ops[l1], out1)\n    self.assertEqual(lc.loss_colocation_ops[l2], out2)\n\n    self.assertEqual(lc.loss_coeffs[l1], 0.1)\n    self.assertEqual(lc.loss_coeffs[l2], 0.9)\n\n  @parameterized.named_parameters(('_EmptyDict', {}),\n                                  ('_PartialDict', {'out2': 0.3}))\n  def testMultipleLossWeights(self, loss_weights):\n    inputs, (out1, out2) = _two_loss_model()\n    model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])\n    loss = ['binary_crossentropy', 'categorical_crossentropy']\n    lc = utils.get_layer_collection(model, loss, loss_weights=loss_weights)\n\n    l1 = 
lc._loss_dict['sigmoid_cross_entropy_loss'][0]\n    self.assertEqual(lc.loss_coeffs[l1], 1.0)\n\n  @parameterized.named_parameters(\n      ('_MissingDict', {'out2': 'categorical_crossentropy'}),\n      ('_MissingList', ['categorical_crossentropy']),\n      ('_ExtraDict', {'out1': 'binary_crossentropy',\n                      'out2': 'categorical_crossentropy',\n                      'blah': 'mse'}),\n      ('_ExtraList', ['mse', 'binary_crossentropy',\n                      'categorical_crossentropy']),\n      ('_WrongName', {'out1': 'binary_crossentropy',\n                      'path2': 'categorical_crossentropy'}))\n  def testLossErrors(self, loss):\n    with self.assertRaisesRegex(ValueError, '.*loss dict.*'):\n      inputs, (out1, out2) = _two_loss_model()\n      model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])\n      utils.get_layer_collection(model, loss)\n\n  @parameterized.named_parameters(\n      ('_EmptyList', []),\n      ('_MissingList', [0.1]),\n      ('_ExtraList', [0.1, 0.9, 0.3]),\n      ('_ExtraDict', {'out1': 0.1, 'out2': 0.9, 'blahblah': 0.4}),\n      ('_Set', {0.1, 0.3}))\n  def testLossWeightErrors(self, loss_weights):\n    with self.assertRaisesRegex(ValueError, '.*loss_weights.*'):\n      inputs, (out1, out2) = _two_loss_model()\n      model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])\n      loss = ['binary_crossentropy', 'categorical_crossentropy']\n      utils.get_layer_collection(model, loss, loss_weights=loss_weights)\n\n  @parameterized.named_parameters(\n      ('_Seperable', layers.SeparableConv2D(13, 5)),\n      ('_ChannelsFirst', layers.Conv2D(11, 3, data_format='channels_first')))\n  def testInvalidCNNLayers(self, layer):\n    with self.assertRaises(ValueError):\n      model = tf.keras.Sequential([layers.Input(shape=(28, 28, 3)), layer])\n      utils.get_layer_collection(model, 'mse')\n\n  @parameterized.named_parameters(\n      ('_List', ['kron', 'kron_in_diag', 'kron_out_diag', 'kron_both_diag']),\n      
('_Dict', {'l1': 'kron', 'l2': 'kron_in_diag', 'l3': 'kron_out_diag',\n                 'l4': 'kron_both_diag'}),\n      ('_DictOneMissing', {'l2': 'kron_in_diag', 'l3': 'kron_out_diag',\n                           'l4': 'kron_both_diag'}))\n  def testFisherApproxLayerNames(self, fisher_approx):\n    model = tf.keras.Sequential([\n        layers.Dense(10, input_shape=(20,), name='l1'),\n        layers.Activation('relu'),\n        layers.Dense(13, activation='relu', name='l2'),\n        layers.Dense(23, trainable=False),\n        layers.Dense(17, name='l3'),\n        layers.Activation('relu'),\n        layers.Dense(3, name='l4')])\n    lc = utils.get_layer_collection(model, 'mse', fisher_approx=fisher_approx)\n    trainable_layers = [model.layers[i] for i in [0, 2, 4, 6]]\n    expected_in_diag_approx = [False, True, False, True]\n    expected_out_diag_approx = [False, False, True, True]\n\n    for layer, in_diag, out_diag in zip(trainable_layers,\n                                        expected_in_diag_approx,\n                                        expected_out_diag_approx):\n      self.assertEqual(\n          in_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_input)\n      self.assertEqual(\n          out_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_output)\n\n  @parameterized.named_parameters(\n      ('_ClassOnly', {layers.Conv2D: 'diagonal'},\n       (fisher_blocks.ConvDiagonalFB, fisher_blocks.ConvDiagonalFB)),\n      ('_NameAndClass', {layers.Conv2D: 'diagonal', 'conv2d_1': None},\n       (fisher_blocks.ConvDiagonalFB, fisher_blocks.ConvKFCBasicFB)))\n  def testFisherApproxLayerClass(self, fisher_approx, block_types):\n    model = _cnn()\n    lc = utils.get_layer_collection(model, 'mse',\n                                    fisher_approx=fisher_approx)\n    trainable_layers = [model.layers[0], model.layers[2]]\n    for layer, block_type in zip(trainable_layers, block_types):\n      
self.assertIsInstance(lc.fisher_blocks[layer.weights], block_type)\n\n  @parameterized.named_parameters(\n      ('_EmptyList', []),\n      ('_ExtraDict', {'conv2d': 'diagonal', layers.Conv2D: 'kron',\n                      'UWaterloo': 'kron'}),\n      ('_ExtraList', ['kron', 'diagonal', 'diagonal']),\n      ('_WrongName', {'conv2d': 'kron', 'path2': 'kron'}))\n  def testFisherApproxErrors(self, fisher_approx):\n    with self.assertRaisesRegex(ValueError, '.*fisher_approx.*'):\n      utils.get_layer_collection(_cnn(), 'mse', fisher_approx=fisher_approx)\n\n  @parameterized.named_parameters(\n      ('_List', ['full', 'diagonal'], ['full', 'diagonal']),\n      ('_SerializedDict',\n       {'dense1': 'full', 'dense2': 'diagonal'},\n       {'dense1': 'full', 'dense2': 'diagonal'}),\n      ('_PartiallySerializedDict',\n       {layers.Dense: 'full', utils._CLASS_NAME_PREFIX + 'Conv2D': 'full'},\n       {utils._CLASS_NAME_PREFIX + 'Dense': 'full',\n        utils._CLASS_NAME_PREFIX + 'Conv2D': 'full'}),\n      ('_Dict',\n       {layers.Dense: 'diagonal', layers.Conv2D: 'full'},\n       {utils._CLASS_NAME_PREFIX + 'Dense': 'diagonal',\n        utils._CLASS_NAME_PREFIX + 'Conv2D': 'full'}))\n  def testSerializeFisherApprox(self, approx, correctly_serialized_approx):\n    serialized_approx = utils.serialize_fisher_approx(approx)\n    self.assertEqual(serialized_approx, correctly_serialized_approx)\n\n  def testSeed(self):\n    lc = utils.get_layer_collection(model=_mlp(), loss='mse', seed=4321)\n    self.assertEqual(lc._loss_dict['squared_error_loss'][0]._default_seed, 4321)\n\n  @parameterized.named_parameters(('_HasShift', True), ('_NoShift', False))\n  def testNormalizationLayers(self, has_shift):\n    model = tf.keras.Sequential([\n        layers.Conv2D(13, 5, input_shape=(28, 28, 3)),\n        layers.BatchNormalization(center=has_shift, name='bn'),\n        layers.Conv2D(23, 3),\n        layers.LayerNormalization(center=has_shift),\n        layers.GlobalMaxPool2D(),\n    
])\n    fisher_approx = {layers.LayerNormalization: 'full', 'bn': 'diagonal'}\n    lc = utils.get_layer_collection(model, 'mse', fisher_approx=fisher_approx)\n    bn_weights = model.layers[1].trainable_weights\n    ln_weights = model.layers[3].trainable_weights\n    if not has_shift:\n      bn_weights, ln_weights = bn_weights[0], ln_weights[0]\n    bn_block = lc.fisher_blocks[bn_weights]\n    ln_block = lc.fisher_blocks[ln_weights]\n    self.assertIsInstance(bn_block, fisher_blocks.ScaleAndShiftDiagonalFB)\n    self.assertIsInstance(ln_block, fisher_blocks.ScaleAndShiftFullFB)\n    self.assertEqual(bn_block._has_shift, has_shift)\n    self.assertEqual(ln_block._has_shift, has_shift)\n\n  def testErrorWithBatchNormNoScale(self):\n    model = tf.keras.Sequential([\n        layers.Conv2D(13, 5, input_shape=(28, 28, 3)),\n        layers.BatchNormalization(scale=False, fused=False),\n        layers.GlobalMaxPool2D(),\n    ])\n    with self.assertRaisesRegex(ValueError, '.*scale=False.*'):\n      utils.get_layer_collection(model, 'binary_crossentropy')\n\n  def testErrorWithLayerNormNoScale(self):\n    model = tf.keras.Sequential([\n        layers.Conv2D(13, 5, input_shape=(28, 28, 3)),\n        layers.LayerNormalization(scale=False),\n        layers.GlobalMaxPool2D(),\n    ])\n    with self.assertRaisesRegex(ValueError, '.*scale=False.*'):\n      utils.get_layer_collection(model, 'binary_crossentropy')\n\n  def testNumBatchNormUsesWithPhase(self):\n    tf.keras.backend.set_learning_phase(1)\n    model = tf.keras.Sequential([\n        layers.Conv2D(13, 5, input_shape=(28, 28, 3)),\n        layers.BatchNormalization(fused=True),\n        layers.GlobalMaxPool2D(),\n    ])\n    lc = utils.get_layer_collection(model, 'binary_crossentropy')\n    for w in model.layers[1].trainable_weights:\n      self.assertEqual(lc._vars_to_uses[w], 1)\n\n  def testNumBatchNormUsesNoPhase(self):\n    model = tf.keras.Sequential([\n        layers.Conv2D(13, 5, input_shape=(28, 28, 3)),\n       
 layers.BatchNormalization(fused=True),\n        layers.GlobalMaxPool2D(),\n    ])\n    lc = utils.get_layer_collection(model, 'binary_crossentropy')\n    for w in model.layers[1].trainable_weights:\n      self.assertEqual(lc._vars_to_uses[w], 2)\n\n  def testModelAsCallable(self):\n    model = tf.keras.Sequential([\n        layers.Conv2D(13, 5),\n        layers.BatchNormalization(name='bn', fused=False),\n        layers.Conv2D(23, 3),\n        layers.LayerNormalization(),\n        layers.GlobalMaxPool2D(),\n    ])\n    inp = tf.random_normal((10, 28, 28, 3))\n    inp = tf.keras.Input(tensor=inp)\n    inp2 = tf.random_normal((10, 28, 28, 3))\n    inp2 = tf.keras.Input(tensor=inp2)\n\n    fisher_approx = {layers.LayerNormalization: 'full', 'bn': 'diagonal'}\n    _ = model(inp)\n    _ = model(inp2)  # with multiple calls, the latest should be registered.\n    lc = utils.get_layer_collection(model, 'mse', fisher_approx=fisher_approx)\n\n    for i in (0, 2):\n      conv_block = lc.fisher_blocks[model.layers[i].trainable_weights]\n      conv_inp = model.layers[i].inbound_nodes[-1].input_tensors\n      conv_out = model.layers[i].inbound_nodes[-1].output_tensors\n      self.assertEqual(conv_inp, conv_block._inputs[0])\n      self.assertEqual(conv_out, conv_block._outputs[0])\n\n  @parameterized.named_parameters(\n      ('_DictApprox', {layers.Dense: 'kron_in_diag',\n                       'l1': 'kron_out_diag',\n                       'l3': 'kron_both_diag'}),\n      ('_ListApprox', ['kron_out_diag', 'kron_in_diag', 'kron_both_diag']))\n  def testNestedModels(self, fisher_approx):\n    # Note this is not a valid trainable model, it was just created to test\n    # order of the dict and list test the DFS order in utils as well.\n    layer1 = layers.Dense(10, input_shape=(1,), name='l1')\n    layer2 = layers.Dense(10, activation='relu', name='l2')\n    layer3 = layers.Dense(10, activation='relu', name='l3')\n\n    inner_model0 = tf.keras.Sequential([layer1])\n\n    
inner_model1 = tf.keras.Sequential()\n    inner_model1.add(inner_model0)\n    inner_model1.add(layers.Activation('relu'))\n    inner_model1.add(layer2)\n\n    inner_inp = layers.Input(shape=(1,))\n    x = layer3(inner_inp)\n    x = layers.Reshape(target_shape=(10, 1))(x)\n    x = layers.GlobalMaxPool1D()(x)\n    inner_model2 = tf.keras.Model(inputs=inner_inp, outputs=x)\n\n    inp = layers.Input(shape=(1,))\n    branch1 = inner_model1(inp)\n    branch2 = inner_model2(inp)\n    out = layers.Add()([branch1, branch2])\n    model = tf.keras.Model(inputs=inp, outputs=out)\n\n    lc = utils.get_layer_collection(\n        model=model, loss='mse', fisher_approx=fisher_approx)\n\n    expected_in_diag_approx = [False, True, True]\n    expected_out_diag_approx = [True, False, True]\n    trainable_layers = [layer1, layer2, layer3]\n    for layer, in_diag, out_diag in zip(trainable_layers,\n                                        expected_in_diag_approx,\n                                        expected_out_diag_approx):\n      self.assertIsInstance(lc.fisher_blocks[layer.weights],\n                            fisher_blocks.FullyConnectedKFACBasicFB)\n      self.assertEqual(\n          in_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_input)\n      self.assertEqual(\n          out_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_output)\n\n  def testMultiOutputNestedModelFails(self):\n    inp = tf.keras.Input(shape=(1,))\n    out1 = layers.Dense(1)(inp)\n    out2 = layers.Dense(1)(inp)\n    model = tf.keras.Model(inputs=inp, outputs=[out1, out2])\n\n    inp2 = tf.keras.Input(shape=(1,))\n    out = model(inp2)\n    model = tf.keras.Model(inputs=inp2, outputs=out)\n\n    with self.assertRaisesRegex(\n        ValueError, 'Nested models with multiple outputs are unsupported.'):\n      utils.get_layer_collection(model, loss=['mse', 'mse'])\n\n\nclass SerializeLossTest(tf.test.TestCase, parameterized.TestCase):\n\n  @parameterized.named_parameters(\n      
('_String', 'binary_crossentropy', 'binary_crossentropy'),\n      ('_KerasLoss', losses.binary_crossentropy, 'binary_crossentropy'),\n      ('_Dict',\n       {'out1': 'binary_crossentropy', 'out2': losses.mean_squared_error},\n       {'out1': 'binary_crossentropy', 'out2': 'mean_squared_error'}),\n      ('_List',\n       ['mse', tf.keras.losses.categorical_crossentropy],\n       ['mse', 'categorical_crossentropy']))\n  def testSerializeLoss(self, loss, correctly_serialized_loss):\n    serialized_loss = utils.serialize_loss(loss)\n    self.assertEqual(serialized_loss, correctly_serialized_loss)\n\n\nclass GetLossFnTest(tf.test.TestCase, parameterized.TestCase):\n\n  def setUp(self):\n    super(GetLossFnTest, self).setUp()\n    tf.reset_default_graph()\n    tf.random.set_random_seed(_SEED)\n\n  @parameterized.parameters(\n      ('categorical_crossentropy', (11, 10), True, True),\n      ('sparse_categorical_crossentropy', (11,), True, False),\n      ('categorical_crossentropy', (11, 10), False, True),\n      ('sparse_categorical_crossentropy', (11,), False, False),\n      (losses.CategoricalCrossentropy(), (11, 10), True, True),\n      (losses.categorical_crossentropy, (11, 10), False, True))\n  def testCrossEntropy(self, loss, label_shape, is_logits, use_regularization):\n    conv_kwargs = {'kernel_regularizer': 'l2'} if use_regularization else {}\n    model_layers = [\n        layers.Conv2D(7, 5, input_shape=(32, 32, 3), **conv_kwargs),\n        layers.Activation('relu'),\n        layers.Conv2D(10, (3, 3), activation='relu', **conv_kwargs),\n        layers.GlobalMaxPool2D()\n    ]\n    if is_logits:\n      model_layers.append(layers.Activation('softmax'))\n    model = tf.keras.Sequential(model_layers)\n    model.compile('sgd', loss)\n    loss_fn = utils.get_loss_fn(model=model, loss=loss)\n\n    x = tf.constant(np.random.random((11, 32, 32, 3)).astype(np.float32))\n    y = tf.constant(np.random.random(label_shape).astype(np.float32))\n    model_loss = 
model.evaluate(x, y, steps=1)\n    fn_loss = tf.keras.backend.get_value(loss_fn((x, y)))\n    fn_loss_w_pred = tf.keras.backend.get_value(\n        loss_fn((x, y), prediction=model(x)))\n    self.assertAlmostEqual(model_loss, fn_loss, places=5)\n    self.assertAlmostEqual(fn_loss, fn_loss_w_pred, places=5)\n\n    model.train_on_batch(np.random.random((11, 32, 32, 3)),\n                         np.random.random(label_shape))\n\n    x = tf.constant(np.random.random((11, 32, 32, 3)).astype(np.float32))\n    y = tf.constant(np.random.random(label_shape).astype(np.float32))\n    model_loss = model.test_on_batch(x, y)\n    fn_loss = tf.keras.backend.get_value(loss_fn((x, y)))\n    fn_loss_w_pred = tf.keras.backend.get_value(\n        loss_fn((x, y), prediction=model(x)))\n    self.assertAlmostEqual(model_loss, fn_loss, places=5)\n    self.assertAlmostEqual(fn_loss, fn_loss_w_pred, places=5)\n\n  @parameterized.parameters('categorical_crossentropy',\n                            losses.CategoricalCrossentropy(),\n                            losses.CategoricalCrossentropy(from_logits=False),\n                            losses.categorical_crossentropy)\n  def testCrossEntropyCustomLoop(self, loss):\n    model_layers = [\n        layers.Conv2D(7, 5, input_shape=(32, 32, 3)),\n        layers.Activation('relu'),\n        layers.Conv2D(10, (3, 3), kernel_regularizer='l2'),\n        layers.GlobalMaxPool2D()\n    ]\n    model = tf.keras.Sequential(model_layers)\n    model.compile('sgd', loss)\n    loss_fn = utils.get_loss_fn(model=model, loss=loss)\n\n    x = np.random.random((11, 32, 32, 3)).astype(np.float32)\n    y = np.random.random((11, 10)).astype(np.float32)\n    tf_x = tf.constant(x)\n    tf_y = tf.constant(y)\n\n    with tf.Session() as sess:\n      sess.run(tf.global_variables_initializer())\n      model_loss = sess.run(\n          model.total_loss,\n          feed_dict={'conv2d_input:0': x, 'global_max_pooling2d_target:0': y})\n      fn_loss = sess.run(loss_fn((tf_x, 
tf_y)))\n      fn_loss_w_pred = sess.run(loss_fn((tf_x, tf_y), prediction=model(tf_x)))\n    self.assertAlmostEqual(model_loss, fn_loss, fn_loss_w_pred)\n\n  @parameterized.parameters(\n      'mse', 'MSE', 'mean_squared_error', losses.mean_squared_error)\n  def testMSE(self, loss):\n    model = _mlp()\n    model.compile('sgd', loss)\n    loss_fn = utils.get_loss_fn(model=model, loss=loss)\n\n    x = tf.constant(np.random.random((23, 1)).astype(np.float32))\n    y = tf.constant(np.random.random((23, 1)).astype(np.float32))\n    model_loss = model.test_on_batch(x, y)\n    fn_loss = tf.keras.backend.get_value(loss_fn((x, y)))\n    fn_loss_w_pred = tf.keras.backend.get_value(\n        loss_fn((x, y), prediction=model(x)))\n    self.assertAlmostEqual(model_loss, fn_loss, fn_loss_w_pred)\n\n  @parameterized.parameters(\n      ({'out1': 'mse', 'out2': losses.categorical_crossentropy},\n       [0.3, 0.7]),\n      (['categorical_crossentropy', losses.MeanSquaredError()],\n       {'out2': 0.1}))\n  def testMultiLoss(self, multi_loss, loss_weights):\n    inps, outs = _two_loss_model()\n    model = tf.keras.Model(inputs=inps, outputs=outs)\n    model.compile('sgd', multi_loss, loss_weights=loss_weights)\n    loss_fn = utils.get_loss_fn(\n        model=model, loss=multi_loss, loss_weights=loss_weights)\n\n    x = tf.constant(np.random.random((11, 28, 28, 1)).astype(np.float32))\n    y_1 = tf.constant(np.random.random((11, 1)).astype(np.float32))\n    y_2 = tf.constant(np.random.random((11, 9)).astype(np.float32))\n    # test_on_batch returns the total loss and the two individual losses.\n    # We just want the total, so we use model_loss[0].\n    model_loss = model.test_on_batch(x, [y_1, y_2])[0]\n    fn_loss = tf.keras.backend.get_value(loss_fn((x, [y_1, y_2])))\n    fn_loss_w_pred = tf.keras.backend.get_value(\n        loss_fn((x, [y_1, y_2]), prediction=model(x)))\n    self.assertAlmostEqual(model_loss, fn_loss, fn_loss_w_pred)\n\n\nif __name__ == '__main__':\n  
tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/layer_collection_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for kfac.layer_collection.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import fisher_blocks\nfrom kfac.python.ops import fisher_factors\nfrom kfac.python.ops import layer_collection\n\n\nclass MockFisherBlock(object):\n  \"\"\"A fake FisherBlock.\"\"\"\n\n  num_registered_towers = 2\n\n  def __init__(self, name='MockFisherBlock'):\n    self.name = name\n\n  def __eq__(self, other):\n    return isinstance(other, MockFisherBlock) and other.name == self.name\n\n  def __hash__(self):\n    return hash(self.name)\n\n\nclass LayerParametersDictTest(tf.test.TestCase):\n\n  def testSetItem(self):\n    \"\"\"Ensure insertion, contains, retrieval works for supported key types.\"\"\"\n    with tf.Graph().as_default():\n      lp_dict = layer_collection.LayerParametersDict()\n\n      x = tf.constant(0)\n      y0 = tf.constant(0)\n      y1 = tf.constant(0)\n      z0 = tf.constant(0)\n      z1 = tf.constant(0)\n      keys = [x, (y0, y1), [z0, z1]]\n      for key in keys:\n        lp_dict[key] = key\n\n      for key in keys:\n        self.assertTrue(key in lp_dict)\n        self.assertEqual(lp_dict[key], 
key)\n\n  def testSetItemOverlap(self):\n    \"\"\"Ensure insertion fails if key overlaps with existing key.\"\"\"\n    with tf.Graph().as_default():\n      lp_dict = layer_collection.LayerParametersDict()\n\n      x = tf.constant(0)\n      y = tf.constant(0)\n      lp_dict[x] = 'value'\n\n      with self.assertRaises(ValueError):\n        lp_dict[(x, y)] = 'value'\n\n      # Ensure 'y' wasn't inserted.\n      self.assertTrue(x in lp_dict)\n      self.assertFalse(y in lp_dict)\n\n\nclass LayerCollectionTest(tf.test.TestCase):\n\n  def testLayerCollectionInit(self):\n    lc = layer_collection.LayerCollection()\n    self.assertEqual(0, len(lc.get_blocks()))\n    self.assertEqual(0, len(lc.get_factors()))\n    self.assertFalse(lc.losses)\n\n  def testRegisterBlocks(self):\n    with tf.Graph().as_default():\n      tf.set_random_seed(200)\n      lc = layer_collection.LayerCollection()\n      lc.register_fully_connected(\n          tf.constant(1), tf.constant(2), tf.constant(3))\n      lc.register_fully_connected(\n          tf.constant(1),\n          tf.constant(2),\n          tf.constant(3),\n          approx=layer_collection.APPROX_DIAGONAL_NAME)\n      lc.register_conv2d(\n          params=tf.ones((2, 3, 4, 5)),\n          strides=[1, 1, 1, 1],\n          padding='SAME',\n          inputs=tf.ones((1, 2, 3, 4)),\n          outputs=tf.ones((1, 1, 1, 5)))\n      lc.register_conv2d(\n          params=tf.ones((2, 3, 4, 5)),\n          strides=[1, 1, 1, 1],\n          padding='SAME',\n          inputs=tf.ones((1, 2, 3, 4)),\n          outputs=tf.ones((1, 1, 1, 5)),\n          approx=layer_collection.APPROX_DIAGONAL_NAME)\n      lc.register_separable_conv2d(\n          depthwise_params=tf.ones((3, 3, 1, 2)),\n          pointwise_params=tf.ones((1, 1, 2, 4)),\n          inputs=tf.ones((32, 5, 5, 1)),\n          depthwise_outputs=tf.ones((32, 5, 5, 2)),\n          pointwise_outputs=tf.ones((32, 5, 5, 4)),\n          strides=[1, 1, 1, 1],\n          padding='SAME')\n      
lc.register_convolution(\n          params=tf.ones((3, 3, 1, 8)),\n          inputs=tf.ones((32, 5, 5, 1)),\n          outputs=tf.ones((32, 5, 5, 8)),\n          padding='SAME')\n      lc.register_generic(\n          tf.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)\n      lc.register_generic(\n          tf.constant(6), 16, approx=layer_collection.APPROX_DIAGONAL_NAME)\n      lc.register_fully_connected_multi(\n          tf.constant(1), (tf.constant(2), tf.constant(3)),\n          (tf.constant(4), tf.constant(5)))\n      lc.register_conv2d_multi(\n          params=tf.ones((2, 3, 4, 5)),\n          strides=[1, 1, 1, 1],\n          padding='SAME',\n          inputs=(tf.ones((1, 2, 3, 4)), tf.ones((5, 6, 7, 8))),\n          outputs=(tf.ones((1, 1, 1, 5)), tf.ones((2, 2, 2, 10))))\n      lc.register_fully_connected_multi(\n          tf.constant((1,)), (tf.constant(2), tf.constant(3)),\n          (tf.constant(4), tf.constant(5)),\n          approx=layer_collection.APPROX_KRONECKER_INDEP_IN_DIAG_NAME)\n      lc.register_fully_connected_multi(\n          tf.constant((1,)), (tf.constant(2), tf.constant(3)),\n          (tf.constant(4), tf.constant(5)),\n          dense_inputs=False,\n          approx=layer_collection.APPROX_KRONECKER_INDEP_IN_DIAG_NAME)\n\n      self.assertEqual(13, len(lc.get_blocks()))\n\n  def testRegisterBlocksMultipleRegistrations(self):\n    with tf.Graph().as_default():\n      tf.set_random_seed(200)\n      lc = layer_collection.LayerCollection()\n      key = tf.constant(1)\n      lc.register_fully_connected(key, tf.constant(2), tf.constant(3))\n      with self.assertRaises(ValueError) as cm:\n        lc.register_generic(key, 16)\n      self.assertIn('already in LayerCollection', str(cm.exception))\n\n  def testRegisterSingleParamNotRegistered(self):\n    x = tf.get_variable('x', initializer=tf.constant(1,))\n    lc = layer_collection.LayerCollection()\n    lc.fisher_blocks = {tf.get_variable('y', initializer=tf.constant(1,)): '1'}\n    
lc._register_block(x, 'foo')\n\n  def testShouldRegisterSingleParamRegistered(self):\n    x = tf.get_variable('x', initializer=tf.constant(1,))\n    lc = layer_collection.LayerCollection()\n    lc.fisher_blocks = {x: '1'}\n    with self.assertRaises(ValueError) as cm:\n      lc._register_block(x, 'foo')\n    self.assertIn('already in LayerCollection', str(cm.exception))\n\n  def testRegisterSingleParamRegisteredInTuple(self):\n    x = tf.get_variable('x', initializer=tf.constant(1,))\n    y = tf.get_variable('y', initializer=tf.constant(1,))\n    lc = layer_collection.LayerCollection()\n    lc.fisher_blocks = {(x, y): '1'}\n    with self.assertRaises(ValueError) as cm:\n      lc._register_block(x, 'foo')\n    self.assertIn('was already registered', str(cm.exception))\n\n  def testRegisterTupleParamNotRegistered(self):\n    x = tf.get_variable('x', initializer=tf.constant(1,))\n    y = tf.get_variable('y', initializer=tf.constant(1,))\n    lc = layer_collection.LayerCollection()\n    lc.fisher_blocks = {tf.get_variable('z', initializer=tf.constant(1,)): '1'}\n\n    lc._register_block((x, y), 'foo')\n    self.assertEqual(set(['1', 'foo']), set(lc.get_blocks()))\n\n  def testRegisterTupleParamRegistered(self):\n    x = tf.get_variable('x', initializer=tf.constant(1,))\n    y = tf.get_variable('y', initializer=tf.constant(1,))\n    lc = layer_collection.LayerCollection()\n    lc.fisher_blocks = {(x, y): '1'}\n\n    with self.assertRaises(ValueError) as cm:\n      lc._register_block((x, y), 'foo')\n    self.assertIn('already in LayerCollection', str(cm.exception))\n\n  def testRegisterTupleParamRegisteredInSuperset(self):\n    x = tf.get_variable('x', initializer=tf.constant(1,))\n    y = tf.get_variable('y', initializer=tf.constant(1,))\n    z = tf.get_variable('z', initializer=tf.constant(1,))\n    lc = layer_collection.LayerCollection()\n    lc.fisher_blocks = {(x, y, z): '1'}\n\n    with self.assertRaises(ValueError) as cm:\n      lc._register_block((x, y), 'foo')\n 
   self.assertIn('was already registered', str(cm.exception))\n\n  def testRegisterTupleParamSomeRegistered(self):\n    x = tf.get_variable('x', initializer=tf.constant(1,))\n    y = tf.get_variable('y', initializer=tf.constant(1,))\n    z = tf.get_variable('z', initializer=tf.constant(1,))\n    lc = layer_collection.LayerCollection()\n    lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}\n\n    with self.assertRaises(ValueError) as cm:\n      lc._register_block((x, y), MockFisherBlock('foo'))\n    self.assertIn('was already registered', str(cm.exception))\n\n  def testRegisterTupleVarSomeRegisteredInOtherTuples(self):\n    x = tf.get_variable('x', initializer=tf.constant(1,))\n    y = tf.get_variable('y', initializer=tf.constant(1,))\n    z = tf.get_variable('z', initializer=tf.constant(1,))\n    w = tf.get_variable('w', initializer=tf.constant(1,))\n    lc = layer_collection.LayerCollection()\n    lc.fisher_blocks = {(x, z): '1', (z, w): '2'}\n\n    with self.assertRaises(ValueError) as cm:\n      lc._register_block((x, y), 'foo')\n    self.assertIn('was already registered', str(cm.exception))\n\n  def testRegisterCategoricalPredictiveDistribution(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n      logits = tf.eye(2)\n\n      lc = layer_collection.LayerCollection()\n      lc.register_categorical_predictive_distribution(logits, seed=200)\n      single_loss = sess.run(lc.total_sampled_loss())\n\n      lc2 = layer_collection.LayerCollection()\n      lc2.register_categorical_predictive_distribution(logits, seed=200)\n      lc2.register_categorical_predictive_distribution(logits, seed=200)\n      double_loss = sess.run(lc2.total_sampled_loss())\n      self.assertAlmostEqual(2 * single_loss, double_loss)\n\n  def testLossFunctionByName(self):\n    \"\"\"Ensure loss functions can be identified by name.\"\"\"\n    with tf.Graph().as_default():\n      logits = tf.eye(2)\n      lc = 
layer_collection.LayerCollection()\n\n      # Create a new loss function by name.\n      lc.register_categorical_predictive_distribution(logits, name='loss1')\n      self.assertEqual(1, len(lc.towers_by_loss))\n\n      # Add logits to same loss function.\n      lc.register_categorical_predictive_distribution(\n          logits, name='loss1', reuse=True)\n      self.assertEqual(1, len(lc.towers_by_loss))\n\n      # Add another new loss function.\n      lc.register_categorical_predictive_distribution(logits, name='loss2')\n      self.assertEqual(2, len(lc.towers_by_loss))\n\n  def testLossFunctionWithoutName(self):\n    \"\"\"Ensure loss functions get unique names if 'name' not specified.\"\"\"\n    with tf.Graph().as_default():\n      logits = tf.eye(2)\n      lc = layer_collection.LayerCollection()\n\n      # Create a new loss function with default names.\n      lc.register_categorical_predictive_distribution(logits)\n      lc.register_categorical_predictive_distribution(logits)\n      self.assertEqual(2, len(lc.losses))\n\n  def testCategoricalPredictiveDistributionMultipleMinibatches(self):\n    \"\"\"Ensure multiple minibatches are registered.\"\"\"\n    with tf.Graph().as_default():\n      batch_size = 3\n      output_size = 2\n      logits = tf.zeros([batch_size, output_size])\n      targets = tf.ones([batch_size], dtype=tf.int32)\n      lc = layer_collection.LayerCollection()\n\n      # Create a new loss function.\n      lc.register_categorical_predictive_distribution(\n          logits, targets=targets, name='loss1')\n\n      # Can add when reuse=True\n      lc.register_categorical_predictive_distribution(\n          logits, targets=targets, name='loss1', reuse=True)\n\n      # Can add when reuse=VARIABLE_SCOPE and reuse=True there.\n      with tf.variable_scope(tf.get_variable_scope(), reuse=True):\n        lc.register_categorical_predictive_distribution(\n            logits,\n            targets=targets,\n            name='loss1',\n            
reuse=layer_collection.VARIABLE_SCOPE)\n\n      # Can't add when reuse=False\n      with self.assertRaises(KeyError):\n        lc.register_categorical_predictive_distribution(\n            logits, targets=targets, name='loss1', reuse=False)\n\n      # Can't add when reuse=VARIABLE_SCOPE and reuse=False there.\n      with self.assertRaises(KeyError):\n        lc.register_categorical_predictive_distribution(\n            logits,\n            targets=targets,\n            name='loss1',\n            reuse=layer_collection.VARIABLE_SCOPE)\n\n      self.assertEqual(len(lc.towers_by_loss), 1)\n      # Three successful registrations.\n      self.assertEqual(len(lc.towers_by_loss[0]), 3)\n\n  def testRegisterCategoricalPredictiveDistributionBatchSize1(self):\n    with tf.Graph().as_default():\n      tf.set_random_seed(200)\n      logits = tf.random_normal((1, 2))\n      lc = layer_collection.LayerCollection()\n\n      lc.register_categorical_predictive_distribution(logits, seed=200)\n\n  def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n      logits = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32)\n      lc = layer_collection.LayerCollection()\n      targets = tf.constant([0, 1], dtype=tf.int32)\n\n      lc.register_categorical_predictive_distribution(logits, targets=targets)\n      single_loss = sess.run(lc.total_loss())\n      self.assertAlmostEqual(1.6265233, single_loss)\n\n  def testRegisterNormalPredictiveDistribution(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n      predictions = tf.constant([[1., 2.], [3., 4]], dtype=tf.float32)\n\n      lc = layer_collection.LayerCollection()\n      lc.register_normal_predictive_distribution(predictions, 1., seed=200)\n      single_loss = sess.run(lc.total_sampled_loss())\n\n      lc2 = layer_collection.LayerCollection()\n      
lc2.register_normal_predictive_distribution(predictions, 1., seed=200)\n      lc2.register_normal_predictive_distribution(predictions, 1., seed=200)\n      double_loss = sess.run(lc2.total_sampled_loss())\n\n      self.assertAlmostEqual(2 * single_loss, double_loss)\n\n  def testRegisterNormalPredictiveDistributionSpecifiedTargets(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n      predictions = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32)\n      lc = layer_collection.LayerCollection()\n      targets = tf.constant([[3., 1.], [4., 2.]], dtype=tf.float32)\n\n      lc.register_normal_predictive_distribution(\n          predictions, 2.**2, targets=targets)\n      single_loss = sess.run(lc.total_loss())\n      self.assertAlmostEqual(7.6983433, single_loss)\n\n  def ensureLayerReuseWorks(self, register_fn):\n    \"\"\"Ensure the 'reuse' keyword argument function as intended.\n\n    Args:\n      register_fn: function for registering a layer. 
Arguments are\n        layer_collection, reuse, and approx.\n    \"\"\"\n    # Fails on second if reuse=False.\n    lc = layer_collection.LayerCollection()\n    register_fn(lc)\n    with self.assertRaises(ValueError):\n      register_fn(lc, reuse=False)\n\n    # Succeeds on second if reuse=True.\n    lc = layer_collection.LayerCollection()\n    register_fn(lc)\n    register_fn(lc, reuse=True)\n\n    # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.\n    lc = layer_collection.LayerCollection()\n    register_fn(lc)\n    with self.assertRaises(ValueError):\n      register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)\n\n    # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.\n    lc = layer_collection.LayerCollection()\n    register_fn(lc)\n    with tf.variable_scope(tf.get_variable_scope(), reuse=True):\n      register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)\n\n    # Fails if block type changes.\n    lc = layer_collection.LayerCollection()\n    register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME)\n    with self.assertRaises(ValueError):\n      register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True)\n\n    # Fails if reuse requested but no FisherBlock exists.\n    lc = layer_collection.LayerCollection()\n    with self.assertRaises(ValueError):\n      register_fn(lc, reuse=True)\n\n  def testRegisterFullyConnectedReuse(self):\n    \"\"\"Ensure the 'reuse' works with register_fully_connected.\"\"\"\n    with tf.Graph().as_default():\n      inputs = tf.ones([2, 10])\n      outputs = tf.zeros([2, 5])\n      params = (\n          tf.get_variable('w', [10, 5]),  #\n          tf.get_variable('b', [5]))\n\n      def register_fn(lc, **kwargs):\n        lc.register_fully_connected(\n            params=params, inputs=inputs, outputs=outputs, **kwargs)\n\n      self.ensureLayerReuseWorks(register_fn)\n\n  def testRegisterConv2dReuse(self):\n    \"\"\"Ensure the 'reuse' works with register_conv2d.\"\"\"\n    with 
tf.Graph().as_default():\n      inputs = tf.ones([2, 5, 5, 10])\n      outputs = tf.zeros([2, 5, 5, 3])\n      params = (\n          tf.get_variable('w', [1, 1, 10, 3]),  #\n          tf.get_variable('b', [3]))\n\n      def register_fn(lc, **kwargs):\n        lc.register_conv2d(\n            params=params,\n            strides=[1, 1, 1, 1],\n            padding='SAME',\n            inputs=inputs,\n            outputs=outputs,\n            **kwargs)\n\n      self.ensureLayerReuseWorks(register_fn)\n\n  def testReuseWithInvalidRegistration(self):\n    \"\"\"Invalid registrations shouldn't overwrite existing blocks.\"\"\"\n    with tf.Graph().as_default():\n      inputs = tf.ones([2, 5, 5, 10])\n      outputs = tf.zeros([2, 5, 5, 3])\n      w = tf.get_variable('w', [1, 1, 10, 3])\n      b = tf.get_variable('b', [3])\n      lc = layer_collection.LayerCollection()\n      lc.register_fully_connected(w, inputs, outputs)\n      self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)\n      with self.assertRaises(ValueError):\n        lc.register_fully_connected((w, b), inputs, outputs, reuse=True)\n      self.assertNotIn((w, b), lc.fisher_blocks)\n      self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)\n      lc.register_fully_connected(w, inputs, outputs, reuse=True)\n      self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2)\n\n  def testMakeOrGetFactor(self):\n    with tf.Graph().as_default():\n      tf.set_random_seed(200)\n      lc = layer_collection.LayerCollection()\n      key = tf.constant(1)\n      lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((key,), 16))\n      lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((key,), 16))\n      lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((tf.constant(2),), 16))\n\n      self.assertEqual(2, len(lc.get_factors()))\n      variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)\n      self.assertTrue(\n          all([var.name.startswith('LayerCollection') for var in 
variables]))\n\n  def testMakeOrGetFactorCustomScope(self):\n    with tf.Graph().as_default():\n      tf.set_random_seed(200)\n      scope = 'Foo'\n      lc = layer_collection.LayerCollection(name=scope)\n      key = tf.constant(1)\n      lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((key,), 16))\n      lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((key,), 16))\n      lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((tf.constant(2),), 16))\n\n      self.assertEqual(2, len(lc.get_factors()))\n      variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)\n      self.assertTrue(all([var.name.startswith(scope) for var in variables]))\n\n  def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):\n    x = tf.get_variable('x', shape=())\n    y = tf.get_variable('y', shape=())\n    z = tf.get_variable('z', shape=())\n    lc = layer_collection.LayerCollection()\n    lc.define_linked_parameters((x, y))\n\n    with self.assertRaises(ValueError):\n      lc.define_linked_parameters((x, z))\n\n  def testIdentifySubsetPreviouslyRegisteredTensor(self):\n    x = tf.get_variable('x', shape=())\n    y = tf.get_variable('y', shape=())\n    lc = layer_collection.LayerCollection()\n    lc.define_linked_parameters((x, y))\n\n    with self.assertRaises(ValueError):\n      lc.define_linked_parameters(x)\n\n  def testSpecifyApproximation(self):\n    w_0 = tf.get_variable('w_0', [10, 10])\n    w_1 = tf.get_variable('w_1', [10, 10])\n\n    b_0 = tf.get_variable('b_0', [10])\n    b_1 = tf.get_variable('b_1', [10])\n\n    x_0 = tf.placeholder(tf.float32, shape=(32, 10))\n    x_1 = tf.placeholder(tf.float32, shape=(32, 10))\n\n    pre_bias_0 = tf.matmul(x_0, w_0)\n    pre_bias_1 = tf.matmul(x_1, w_1)\n\n    # Build the fully connected layers in the graph.\n    pre_bias_0 + b_0  # pylint: disable=pointless-statement\n    pre_bias_1 + b_1  # pylint: disable=pointless-statement\n\n    lc = layer_collection.LayerCollection()\n    lc.define_linked_parameters(\n  
      w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME)\n    lc.define_linked_parameters(\n        w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME)\n    lc.define_linked_parameters(\n        b_0, approximation=layer_collection.APPROX_FULL_NAME)\n    lc.define_linked_parameters(\n        b_1, approximation=layer_collection.APPROX_FULL_NAME)\n\n    lc.register_fully_connected(w_0, x_0, pre_bias_0)\n    lc.register_fully_connected(\n        w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME)\n    self.assertIsInstance(lc.fisher_blocks[w_0],\n                          fisher_blocks.FullyConnectedDiagonalFB)\n    self.assertIsInstance(lc.fisher_blocks[w_1],\n                          fisher_blocks.FullyConnectedKFACBasicFB)\n\n    lc.register_generic(b_0, batch_size=1)\n    lc.register_generic(\n        b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME)\n    self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.NaiveFullFB)\n    self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)\n\n  def testDefaultLayerCollection(self):\n    with tf.Graph().as_default():\n      # Can't get default if there isn't one set.\n      with self.assertRaises(ValueError):\n        layer_collection.get_default_layer_collection()\n\n      # Can't set default twice.\n      lc = layer_collection.LayerCollection()\n      layer_collection.set_default_layer_collection(lc)\n      with self.assertRaises(ValueError):\n        layer_collection.set_default_layer_collection(lc)\n\n      # Same as one set.\n      self.assertTrue(lc is layer_collection.get_default_layer_collection())\n\n      # Can set to None.\n      layer_collection.set_default_layer_collection(None)\n      with self.assertRaises(ValueError):\n        layer_collection.get_default_layer_collection()\n\n      # as_default() is the same as setting/clearing.\n      with lc.as_default():\n        self.assertTrue(lc is layer_collection.get_default_layer_collection())\n  
    with self.assertRaises(ValueError):\n        layer_collection.get_default_layer_collection()\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/loss_functions_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for kfac.loss_functions.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import loss_functions\n\n\nclass InsertSliceInZerosTest(tf.test.TestCase):\n\n  def testBadShape(self):\n    bad_shaped_ones = tf.ones(shape=[1, 3])  # n.b. 
shape[1] != 1\n    with self.assertRaises(ValueError):\n      loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17)\n\n  def test3d(self):\n    input_tensor = tf.constant([[[1, 2]], [[3, 4]]])\n    expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]]\n    op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0)\n    with self.test_session() as sess:\n      actual_output_array = sess.run(op)\n    self.assertAllEqual(expected_output_array, actual_output_array)\n\n\nclass CategoricalLogitsNegativeLogProbLossTest(tf.test.TestCase):\n\n  def testSample(self):\n    \"\"\"Ensure samples can be drawn.\"\"\"\n    with tf.Graph().as_default(), self.test_session() as sess:\n      logits = np.asarray([\n          [0., 0., 0.],  #\n          [1., -1., 0.]\n      ]).astype(np.float32)\n      loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(\n          tf.constant(logits))\n      sample = loss.sample(42)\n      sample = sess.run(sample)\n      self.assertEqual(sample.shape, (2,))\n\n  def testEvaluateOnTargets(self):\n    \"\"\"Ensure log probability can be evaluated correctly.\"\"\"\n    with tf.Graph().as_default(), self.test_session() as sess:\n      logits = np.asarray([\n          [0., 0., 0.],  #\n          [1., -1., 0.]\n      ]).astype(np.float32)\n      targets = np.asarray([2, 1]).astype(np.int32)\n      loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(\n          tf.constant(logits), targets=tf.constant(targets))\n      neg_log_prob = loss.evaluate()\n      neg_log_prob = sess.run(neg_log_prob)\n\n      # Calculate explicit log probability of targets.\n      probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)\n      log_probs = np.log([\n          probs[0, targets[0]],  #\n          probs[1, targets[1]]\n      ])\n      expected_log_prob = np.sum(log_probs)\n\n      self.assertAllClose(neg_log_prob, -expected_log_prob)\n\n  def testEvaluateOnSample(self):\n    \"\"\"Ensure log probability of a sample can 
be drawn.\"\"\"\n    with tf.Graph().as_default(), self.test_session() as sess:\n      logits = np.asarray([\n          [0., 0., 0.],  #\n          [1., -1., 0.]\n      ]).astype(np.float32)\n      loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(\n          tf.constant(logits))\n      neg_log_prob = loss.evaluate_on_sample(42)\n\n      # Simply ensure this doesn't crash. As the output is random, it's\n      # difficult to say if the output is correct or not...\n      neg_log_prob = sess.run(neg_log_prob)\n\n  def testMultiplyFisherSingleVector(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      logits = np.array([1., 2., 3.])\n      loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)\n\n      # the LossFunction.multiply_fisher docstring only says it supports the\n      # case where the vector is the same shape as the input natural parameters\n      # (i.e. the logits here), but here we also test leading dimensions\n      vector = np.array([1., 2., 3.])\n      vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)]\n\n      probs = np.exp(logits - np.logaddexp.reduce(logits))\n      fisher = np.diag(probs) - np.outer(probs, probs)\n\n      for vector in vectors:\n        result = loss.multiply_fisher(vector)\n        expected_result = np.dot(vector, fisher)\n        self.assertAllClose(expected_result, sess.run(result))\n\n  def testMultiplyFisherBatch(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      logits = np.array([[1., 2., 3.], [4., 6., 8.]])\n      loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)\n\n      vector = np.array([[1., 2., 3.], [5., 3., 1.]])\n\n      na = np.newaxis\n      probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1,\n                                                  keepdims=True))\n      fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :]\n\n      result = loss.multiply_fisher(vector)\n      
expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :]\n      self.assertEqual(sess.run(result).shape, logits.shape)\n      self.assertAllClose(expected_result, sess.run(result))\n\n\nclass OnehotCategoricalLogitsNegativeLogProbLossTest(tf.test.TestCase):\n\n  def testSample(self):\n    \"\"\"Ensure samples can be drawn.\"\"\"\n    with tf.Graph().as_default(), self.test_session() as sess:\n      logits = np.asarray([\n          [0., 0., 0.],  #\n          [1., -1., 0.]\n      ]).astype(np.float32)\n      loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(\n          tf.constant(logits))\n      sample = loss.sample(42)\n      sample = sess.run(sample)\n      self.assertEqual(sample.shape, (2, 3))\n\n  def testEvaluateOnTargets(self):\n    \"\"\"Ensure log probability can be evaluated correctly.\"\"\"\n    with tf.Graph().as_default(), self.test_session() as sess:\n      logits = np.asarray([\n          [0., 0., 0.],  #\n          [1., -1., 0.]\n      ]).astype(np.float32)\n      targets = np.asarray([2, 1]).astype(np.int32)\n      loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(\n          tf.constant(logits), targets=tf.one_hot(targets, 3))\n      neg_log_prob = loss.evaluate()\n      neg_log_prob = sess.run(neg_log_prob)\n\n      # Calculate explicit log probability of targets.\n      probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)\n      log_probs = np.log([\n          probs[0, targets[0]],  #\n          probs[1, targets[1]]\n      ])\n      expected_log_prob = np.sum(log_probs)\n\n      self.assertAllClose(neg_log_prob, -expected_log_prob)\n\n  def testEvaluateOnSample(self):\n    \"\"\"Ensure log probability of a sample can be drawn.\"\"\"\n    with tf.Graph().as_default(), self.test_session() as sess:\n      logits = np.asarray([\n          [0., 0., 0.],  #\n          [1., -1., 0.]\n      ]).astype(np.float32)\n      loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(\n          
tf.constant(logits))\n      neg_log_prob = loss.evaluate_on_sample(42)\n\n      # Simply ensure this doesn't crash. As the output is random, it's\n      # difficult to say if the output is correct or not...\n      neg_log_prob = sess.run(neg_log_prob)\n\nif __name__ == \"__main__\":\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/op_queue_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for kfac.op_queue.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import op_queue\n\n\nclass OpQueueTest(tf.test.TestCase):\n\n  def testNextOp(self):\n    \"\"\"Ensures all ops get selected eventually.\"\"\"\n    with tf.Graph().as_default():\n      ops = [\n          tf.add(1, 2),\n          tf.subtract(1, 2),\n          tf.reduce_mean([1, 2]),\n      ]\n      queue = op_queue.OpQueue(ops, seed=0)\n\n      with self.test_session() as sess:\n        # Ensure every inv update op gets selected.\n        selected_ops = set([queue.next_op(sess) for _ in ops])\n        self.assertEqual(set(ops), set(selected_ops))\n\n        # Ensure additional calls don't create any new ops.\n        selected_ops.add(queue.next_op(sess))\n        self.assertEqual(set(ops), set(selected_ops))\n\n\nif __name__ == \"__main__\":\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/optimizer_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for kfac.optimizer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import fisher_factors as ff\nfrom kfac.python.ops import layer_collection as lc\nfrom kfac.python.ops import optimizer\n\n\ndef dummy_layer_collection():\n  lcoll = lc.LayerCollection()\n  dummy = tf.constant([1., 2.])\n  lcoll.register_categorical_predictive_distribution(logits=dummy)\n  return lcoll\n\n\nclass OptimizerTest(tf.test.TestCase):\n\n  def testOptimizerInitInvalidMomentumRegistration(self):\n    with self.assertRaises(ValueError):\n      optimizer.KfacOptimizer(\n          0.1, 0.2, lc.LayerCollection(), 0.3, momentum_type='foo')\n\n  def testOptimizerInit(self):\n    with tf.Graph().as_default():\n      layer_collection = lc.LayerCollection()\n\n      inputs = tf.ones((2, 1)) * 2\n      weights_val = np.ones((1, 1), dtype=np.float32) * 3.\n      weights = tf.get_variable('w', initializer=tf.constant(weights_val))\n      bias = tf.get_variable(\n          'b', initializer=tf.zeros_initializer(), shape=(1, 1))\n      output = tf.matmul(inputs, weights) + bias\n\n      
layer_collection.register_fully_connected((weights, bias), inputs, output)\n\n      logits = tf.tanh(output)\n      targets = tf.constant([[0.], [1.]])\n      output = tf.reduce_mean(\n          tf.nn.softmax_cross_entropy_with_logits(\n              logits=logits, labels=targets))\n\n      layer_collection.register_categorical_predictive_distribution(logits)\n\n      optimizer.KfacOptimizer(\n          0.1,\n          0.2,\n          layer_collection,\n          0.3,\n          momentum=0.5,\n          momentum_type='regular')\n\n  def testSquaredFisherNorm(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      grads_and_vars = [(tf.constant([[1., 2.], [3., 4.]]), None),\n                        (tf.constant([[2., 3.], [4., 5.]]), None)]\n      pgrads_and_vars = [(tf.constant([[3., 4.], [5., 6.]]), None),\n                         (tf.constant([[7., 8.], [9., 10.]]), None)]\n      opt = optimizer.KfacOptimizer(0.1, 0.2, dummy_layer_collection(), 0.3)\n      sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)\n      self.assertAlmostEqual(174., sess.run(sq_norm), places=5)\n\n  def testUpdateClipCoeff(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      grads_and_vars = [(tf.constant([[1., 2.], [3., 4.]]), None),\n                        (tf.constant([[2., 3.], [4., 5.]]), None)]\n      pgrads_and_vars = [(tf.constant([[3., 4.], [5., 6.]]), None),\n                         (tf.constant([[7., 8.], [9., 10.]]), None)]\n      lrate = 0.1\n\n      # Note: without rescaling, the squared Fisher norm of the update\n      # is 1.74\n\n      # If the update already satisfies the norm constraint, there should\n      # be no rescaling.\n      opt = optimizer.KfacOptimizer(\n          lrate, 0.2, dummy_layer_collection(), 0.3, norm_constraint=10.,\n          name='KFAC_1')\n      coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)\n      self.assertAlmostEqual(1., sess.run(coeff), places=5)\n\n      # 
If the update violates the constraint, it should be rescaled to\n      # be on the constraint boundary.\n      opt = optimizer.KfacOptimizer(\n          lrate, 0.2, dummy_layer_collection(), 0.3, norm_constraint=0.5,\n          name='KFAC_2')\n      coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)\n      sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)\n      sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad\n      self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5)\n\n  def testUpdateVelocities(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      layers = lc.LayerCollection()\n      layers.register_categorical_predictive_distribution(tf.constant([1.0]))\n      opt = optimizer.KfacOptimizer(\n          0.1, 0.2, layers, 0.3, momentum=0.5, momentum_type='regular')\n      x = tf.get_variable('x', initializer=tf.ones((2, 2)))\n      y = tf.get_variable('y', initializer=tf.ones((2, 2)) * 2)\n      vec1 = tf.ones((2, 2)) * 3\n      vec2 = tf.ones((2, 2)) * 4\n\n      model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)\n      update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5)\n      opt_vars = [\n          v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)\n          if v not in model_vars\n      ]\n\n      sess.run(tf.global_variables_initializer())\n      old_opt_vars = sess.run(opt_vars)\n\n      # Optimizer vars start out at 0.\n      for opt_var in old_opt_vars:\n        self.assertAllEqual(sess.run(tf.zeros_like(opt_var)), opt_var)\n\n      sess.run(update_op)\n      new_opt_vars = sess.run(opt_vars)\n      # After one update, the velocities are equal to the vectors.\n      for vec, opt_var in zip([vec1, vec2], new_opt_vars):\n        self.assertAllEqual(sess.run(vec), opt_var)\n\n      sess.run(update_op)\n      final_opt_vars = sess.run(opt_vars)\n      for first, second in zip(new_opt_vars, final_opt_vars):\n        self.assertFalse(np.equal(first, 
second).all())\n\n  def testApplyGradients(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      layer_collection = lc.LayerCollection()\n\n      inputs = tf.ones((2, 1)) * 2\n      weights_val = np.ones((1, 1), dtype=np.float32) * 3.\n      weights = tf.get_variable('w', initializer=tf.constant(weights_val))\n      bias = tf.get_variable(\n          'b', initializer=tf.zeros_initializer(), shape=(1, 1))\n      output = tf.matmul(inputs, weights) + bias\n\n      layer_collection.register_fully_connected((weights, bias), inputs, output)\n\n      preds = output\n\n      targets = tf.constant([[0.34], [1.56]])\n      output = tf.reduce_mean(tf.square(targets - preds))\n\n      layer_collection.register_squared_error_loss(preds)\n\n      opt = optimizer.KfacOptimizer(\n          0.1,\n          0.2,\n          layer_collection,\n          cov_ema_decay=0.3,\n          momentum=0.5,\n          momentum_type='regular')\n      (cov_update_thunks,\n       inv_update_thunks) = opt.make_vars_and_create_op_thunks()\n      cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)\n      inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)\n\n      grads_and_vars = opt.compute_gradients(output, [weights, bias])\n      all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]\n\n      op = opt.apply_gradients(grads_and_vars)\n\n      sess.run(tf.global_variables_initializer())\n      old_vars = sess.run(all_vars)\n      sess.run(cov_update_ops)\n      sess.run(inv_update_ops)\n      sess.run(op)\n      new_vars = sess.run(all_vars)\n\n      for old_var, new_var in zip(old_vars, new_vars):\n        self.assertNotEqual(old_var, new_var)\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/periodic_inv_cov_update_kfac_opt_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for l.d.tf.optimizers.python.PeriodicInvCovUpdateKfacOpt class.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport sonnet as snt\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import layer_collection\nfrom kfac.python.ops.kfac_utils import periodic_inv_cov_update_kfac_opt\nfrom kfac.python.ops.tensormatch import graph_search\n\n_BATCH_SIZE = 128\n\n\ndef _construct_layer_collection(layers, all_logits, var_list):\n  for idx, logits in enumerate(all_logits):\n    tf.logging.info(\"Registering logits: %s\", logits)\n    with tf.variable_scope(tf.get_variable_scope(), reuse=(idx > 0)):\n      layers.register_categorical_predictive_distribution(\n          logits, name=\"register_logits\")\n  batch_size = all_logits[0].shape.as_list()[0]\n  vars_to_register = var_list if var_list else tf.trainable_variables()\n  graph_search.register_layers(layers, vars_to_register, batch_size)\n\n\nclass PeriodicInvCovUpdateKfacOptTest(tf.test.TestCase):\n\n  def test_train(self):\n    image = tf.random_uniform(shape=(_BATCH_SIZE, 784), maxval=1.)\n    labels = tf.random_uniform(shape=(_BATCH_SIZE,), maxval=10, dtype=tf.int32)\n    labels_one_hot = 
tf.one_hot(labels, 10)\n\n    model = snt.Sequential([snt.BatchFlatten(), snt.nets.MLP([128, 128, 10])])\n    logits = model(image)\n    all_losses = tf.nn.softmax_cross_entropy_with_logits_v2(\n        logits=logits, labels=labels_one_hot)\n    loss = tf.reduce_mean(all_losses)\n    layers = layer_collection.LayerCollection()\n    optimizer = periodic_inv_cov_update_kfac_opt.PeriodicInvCovUpdateKfacOpt(\n        invert_every=10,\n        cov_update_every=1,\n        learning_rate=0.03,\n        cov_ema_decay=0.95,\n        damping=100.,\n        layer_collection=layers,\n        momentum=0.9,\n        num_burnin_steps=0,\n        placement_strategy=\"round_robin\")\n    _construct_layer_collection(layers, [logits], tf.trainable_variables())\n\n    train_step = optimizer.minimize(loss)\n    counter = optimizer.counter\n    max_iterations = 50\n\n    with self.test_session() as sess:\n      sess.run(tf.global_variables_initializer())\n      coord = tf.train.Coordinator()\n      tf.train.start_queue_runners(sess=sess, coord=coord)\n      for iteration in range(max_iterations):\n        sess.run([loss, train_step])\n        counter_ = sess.run(counter)\n        self.assertEqual(counter_, iteration + 1.0)\n\n\nif __name__ == \"__main__\":\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/kernel_tests/utils_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for kfac.utils.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport numpy as np\nimport tensorflow.compat.v1 as tf\nfrom kfac.python.ops import utils\n\n\nclass SequenceDictTest(tf.test.TestCase):\n\n  def testSequenceDictInit(self):\n    seq_dict = utils.SequenceDict()\n    self.assertFalse(seq_dict._dict)\n\n  def testSequenceDictInitWithIterable(self):\n    reg_dict = {'a': 'foo', 'b': 'bar'}\n    itr = zip(reg_dict.keys(), reg_dict.values())\n    seq_dict = utils.SequenceDict(itr)\n    self.assertEqual(reg_dict, seq_dict._dict)\n\n  def testGetItemSingleKey(self):\n    seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})\n    self.assertEqual('foo', seq_dict['a'])\n\n  def testGetItemMultipleKeys(self):\n    seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})\n    self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')])\n\n  def testSetItemSingleKey(self):\n    seq_dict = utils.SequenceDict()\n    seq_dict['a'] = 'foo'\n    self.assertEqual([('a', 'foo')], seq_dict.items())\n\n  def testSetItemMultipleKeys(self):\n    seq_dict = utils.SequenceDict()\n    keys = ('a', 'b', 'c')\n    values = ('foo', 'bar', 'baz')\n    seq_dict[keys] = 
values\n    self.assertItemsEqual(list(zip(keys, values)), seq_dict.items())\n\n\nclass SubGraphTest(tf.test.TestCase):\n\n  def testBasicGraph(self):\n    a = tf.constant([[1., 2.], [3., 4.]])\n    b = tf.constant([[5., 6.], [7., 8.]])\n    c = a + b\n    d = a * b\n    sub_graph = utils.SubGraph((c,))\n    self.assertTrue(sub_graph.is_member(a))\n    self.assertTrue(sub_graph.is_member(b))\n    self.assertTrue(sub_graph.is_member(c))\n    self.assertFalse(sub_graph.is_member(d))\n\n  def testRepeatedAdds(self):\n    a = tf.constant([[1., 2.], [3., 4.]])\n    b = tf.constant([[5., 6.], [7., 8.]])\n    c = a + b + a  # note that a appears twice in this graph\n    sub_graph = utils.SubGraph((c,))\n    self.assertTrue(sub_graph.is_member(a))\n    self.assertTrue(sub_graph.is_member(b))\n    self.assertTrue(sub_graph.is_member(c))\n\n  def testFilterList(self):\n    a = tf.constant([[1., 2.], [3., 4.]])\n    b = tf.constant([[5., 6.], [7., 8.]])\n    c = a + b\n    d = a * b\n    sub_graph = utils.SubGraph((c,))\n    input_list = [b, d]\n    filtered_list = sub_graph.filter_list(input_list)\n    self.assertEqual(filtered_list, [b])\n\n  def testVariableUses(self):\n    with tf.Graph().as_default():\n      var = tf.get_variable('var', shape=[10, 10])\n      resource_var = tf.get_variable(\n          'resource_var', shape=[10, 10], use_resource=True)\n      x = tf.zeros([3, 10])\n      z0 = tf.matmul(x, var) + tf.matmul(x, var)\n      z1 = tf.matmul(x, resource_var)\n      sub_graph = utils.SubGraph((z0, z1))\n      self.assertEqual(2, sub_graph.variable_uses(var))\n      self.assertEqual(1, sub_graph.variable_uses(resource_var))\n\n  def testVariableUsesRelayOps(self):\n    with tf.Graph().as_default():\n      a = tf.get_variable(\"a\", shape=[2, 2])\n      b = tf.get_variable(\"b\", shape=[2, 2])\n      ai = tf.identity(a)\n      c = tf.matmul(ai, b)\n      d = tf.matmul(ai, b)\n\n      sub_graph = utils.SubGraph((c, d))\n      self.assertEqual(2, 
sub_graph.variable_uses(a))\n      self.assertEqual(2, sub_graph.variable_uses(b))\n\n\nclass UtilsTest(tf.test.TestCase):\n\n  def _fully_connected_layer_params(self):\n    weights_part = tf.constant([[1., 2.], [4., 3.]])\n    bias_part = tf.constant([1., 2.])\n    return (weights_part, bias_part)\n\n  def _conv_layer_params(self):\n    weights_shape = 2, 2, 3, 4\n    biases_shape = weights_shape[-1:]\n    weights = tf.constant(np.random.RandomState(0).randn(*weights_shape))\n    biases = tf.constant(np.random.RandomState(1).randn(*biases_shape))\n    return (weights, biases)\n\n  def testFullyConnectedLayerParamsTupleToMat2d(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n      layer_params = self._fully_connected_layer_params()\n      output = utils.layer_params_to_mat2d(layer_params)\n      self.assertListEqual([3, 2], output.get_shape().as_list())\n      self.assertAllClose(\n          sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]]))\n\n  def testFullyConnectedLayerParamsTensorToMat2d(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n      layer_params = self._fully_connected_layer_params()\n      output = utils.layer_params_to_mat2d(layer_params[0])\n      self.assertListEqual([2, 2], output.get_shape().as_list())\n      self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]]))\n\n  def testConvLayerParamsTupleToMat2d(self):\n    with tf.Graph().as_default():\n      tf.set_random_seed(200)\n      layer_params = self._conv_layer_params()\n      output = utils.layer_params_to_mat2d(layer_params)\n      self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list())\n\n  def testKron(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      mat1 = np.array([[1., 2.], [3., 4.]])\n      mat2 = np.array([[5., 6.], [7., 8.]])\n      mat1_tf = tf.constant(mat1)\n      mat2_tf = tf.constant(mat2)\n      ans_tf = 
sess.run(utils.kronecker_product(mat1_tf, mat2_tf))\n      ans_np = np.kron(mat1, mat2)\n      self.assertAllClose(ans_tf, ans_np)\n\n  def testMat2dToFullyConnectedLayerParamsTuple(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n      vector_template = self._fully_connected_layer_params()\n      mat2d = tf.constant([[5., 4.], [3., 2.], [1., 0.]])\n\n      output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))\n\n      self.assertIsInstance(output, tuple)\n      self.assertEqual(len(output), 2)\n      a, b = output\n      self.assertAllClose(a, np.array([[5., 4.], [3., 2.]]))\n      self.assertAllClose(b, np.array([1., 0.]))\n\n  def testMat2dToFullyConnectedLayerParamsTensor(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n      vector_template = self._fully_connected_layer_params()[0]\n      mat2d = tf.constant([[5., 4.], [3., 2.]])\n\n      output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))\n\n      self.assertAllClose(output, np.array([[5., 4.], [3., 2.]]))\n\n  def testTensorsToColumn(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n\n      vector = tf.constant(np.array([[0., 1.], [2., 3.]]))\n      output = utils.tensors_to_column(vector)\n      self.assertListEqual([4, 1], output.get_shape().as_list())\n      self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None])\n\n      vector = self._fully_connected_layer_params()\n      output = utils.tensors_to_column(vector)\n      self.assertListEqual([6, 1], output.get_shape().as_list())\n      self.assertAllClose(\n          sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None])\n\n      vector = list(vector)\n      vector.append(tf.constant([[6.], [7.], [8.], [9.]]))\n\n      output = utils.tensors_to_column(vector)\n      self.assertListEqual([10, 1], output.get_shape().as_list())\n      
self.assertAllClose(\n          sess.run(output),\n          np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None])\n\n  def testColumnToTensors(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n\n      vector_template = tf.constant(np.array([[0., 1.], [2., 3.]]))\n      colvec = tf.constant(np.arange(4.)[:, None])\n      output = sess.run(utils.column_to_tensors(vector_template, colvec))\n      self.assertAllClose(output, np.array([[0., 1.], [2., 3.]]))\n\n      vector_template = self._fully_connected_layer_params()\n      colvec = tf.constant(np.arange(6.)[:, None])\n      output = sess.run(utils.column_to_tensors(vector_template, colvec))\n\n      self.assertIsInstance(output, tuple)\n      self.assertEqual(len(output), 2)\n      a, b = output\n      self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))\n      self.assertAllClose(b, np.array([4., 5.]))\n\n      vector_template = list(vector_template)\n      vector_template.append(tf.constant([[6.], [7.], [8.], [9.]]))\n      colvec = tf.constant(np.arange(10.)[:, None])\n      output = sess.run(utils.column_to_tensors(vector_template, colvec))\n      self.assertIsInstance(output, tuple)\n      self.assertEqual(len(output), 3)\n      a, b, c = output\n      self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))\n      self.assertAllClose(b, np.array([4., 5.]))\n      self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]]))\n\n  def testPosDefInvCholesky(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n      np.random.seed(0)\n      square = lambda x: np.dot(x, x.T)\n\n      size = 3\n      x = square(np.random.randn(size, size))\n      damp = 0.1\n      identity = tf.eye(size, dtype=tf.float64)\n\n      tf_inv = utils.posdef_inv_cholesky(tf.constant(x), identity, damp)\n      np_inv = np.linalg.inv(x + damp * np.eye(size))\n      self.assertAllClose(sess.run(tf_inv), np_inv)\n\n  def 
testPosDefInvMatrixInverse(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      tf.set_random_seed(200)\n      np.random.seed(0)\n      square = lambda x: np.dot(x, x.T)\n\n      size = 3\n      x = square(np.random.randn(size, size))\n      damp = 0.1\n      identity = tf.eye(size, dtype=tf.float64)\n\n      tf_inv = utils.posdef_inv_matrix_inverse(tf.constant(x), identity, damp)\n      np_inv = np.linalg.inv(x + damp * np.eye(size))\n      self.assertAllClose(sess.run(tf_inv), np_inv)\n\n  def testBatchExecute(self):\n    \"\"\"Ensure batch_execute runs in a round-robin fashion.\"\"\"\n\n    def increment_var(var):\n      return lambda: var.assign_add(1)\n\n    with tf.Graph().as_default(), self.test_session() as sess:\n      i = tf.get_variable('i', initializer=0)\n      accumulators = [\n          tf.get_variable('var%d' % j, initializer=0) for j in range(3)\n      ]\n      thunks = [increment_var(var) for var in accumulators]\n      increment_accumulators = utils.batch_execute(i, thunks, 2)\n      increment_i = i.assign_add(1)\n\n      sess.run(tf.global_variables_initializer())\n\n      # Ensure one op per thunk.\n      self.assertEqual(3, len(increment_accumulators))\n\n      # Ensure round-robin execution.\n      values = []\n      for _ in range(5):\n        sess.run(increment_accumulators)\n        sess.run(increment_i)\n        values.append(sess.run(accumulators))\n      self.assertAllClose(\n          [\n              [1, 1, 0],  #\n              [2, 1, 1],  #\n              [2, 2, 2],  #\n              [3, 3, 2],  #\n              [4, 3, 3]\n          ],\n          values)\n\n  def testExtractConvolutionPatches(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      batch_size = 10\n      image_spatial_shape = [9, 10, 11]\n      in_channels = out_channels = 32\n      kernel_spatial_shape = [5, 3, 3]\n      spatial_strides = [1, 2, 1]\n      spatial_dilation = [1, 1, 1]\n      padding = 'SAME'\n\n      
images = tf.random_uniform(\n          [batch_size] + image_spatial_shape + [in_channels], seed=0)\n      kernel_shape = kernel_spatial_shape + [in_channels, out_channels]\n      kernel = tf.random_uniform(kernel_shape, seed=1)\n\n      # Ensure shape matches expectation.\n      patches = utils.extract_convolution_patches(\n          images,\n          kernel_shape,\n          padding,\n          strides=spatial_strides,\n          dilation_rate=spatial_dilation)\n      result_spatial_shape = (\n          patches.shape.as_list()[1:1 + len(image_spatial_shape)])\n      self.assertEqual(patches.shape.as_list(),\n                       [batch_size] + result_spatial_shape +\n                       kernel_spatial_shape + [in_channels])\n\n      # Ensure extract...patches() + matmul() and convolution() implementation\n      # give the same answer.\n      outputs = tf.nn.convolution(\n          images,\n          kernel,\n          padding,\n          strides=spatial_strides,\n          dilation_rate=spatial_dilation)\n\n      patches_flat = tf.reshape(\n          patches, [-1, np.prod(kernel_spatial_shape) * in_channels])\n      kernel_flat = tf.reshape(kernel, [-1, out_channels])\n      outputs_flat = tf.matmul(patches_flat, kernel_flat)\n\n      outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])\n      self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())\n\n  def testExtractPointwiseConv2dPatches(self):\n    with tf.Graph().as_default(), self.test_session() as sess:\n      batch_size = 10\n      image_height = image_width = 8\n      in_channels = out_channels = 3\n      kernel_height = kernel_width = 1\n      strides = [1, 1, 1, 1]\n      padding = 'VALID'\n\n      images = tf.random_uniform(\n          [batch_size, image_height, image_width, in_channels], seed=0)\n      kernel_shape = [kernel_height, kernel_width, in_channels, out_channels]\n      kernel = tf.random_uniform(kernel_shape, seed=1)\n\n      # Ensure shape matches expectation.\n      
patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape)\n      self.assertEqual(patches.shape.as_list(), [\n          batch_size, image_height, image_width, kernel_height, kernel_width,\n          in_channels\n      ])\n\n      # Ensure extract...patches() + matmul() and conv2d() implementation\n      # give the same answer.\n      outputs = tf.nn.conv2d(images, kernel, strides, padding)\n\n      patches_flat = tf.reshape(\n          patches, [-1, kernel_height * kernel_width * in_channels])\n      kernel_flat = tf.reshape(kernel, [-1, out_channels])\n      outputs_flat = tf.matmul(patches_flat, kernel_flat)\n\n      outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])\n      self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())\n\n\nclass AccumulatorVariableTest(tf.test.TestCase):\n\n  def test_assign_to_var(self):\n    var_shape = (2, 3)\n    acc_var = utils.AccumulatorVariable(\n        name='test_acc_var', dtype=tf.float32, shape=var_shape)\n    values = [\n        3. * tf.ones(shape=var_shape), 7. * tf.ones(shape=var_shape),\n        11. * tf.ones(shape=var_shape)\n    ]\n    acc_ops = []\n    accc_ops_after_reset = []\n    for value in values:\n      acc_ops.append(acc_var.accumulate(value))\n\n    for value in values[:2]:\n      accc_ops_after_reset.append(acc_var.accumulate(value))\n\n    init_op = tf.global_variables_initializer()\n    with self.test_session() as sess:\n\n      sess.run([init_op])\n\n      for acc_op in acc_ops:\n        sess.run(acc_op)\n\n      acc_var_value = sess.run(acc_var.value)\n\n      self.assertAllEqual(acc_var_value, 7.*np.ones(shape=var_shape))\n\n      sess.run(acc_var.reset())\n\n      for acc_op in accc_ops_after_reset:\n        sess.run(acc_op)\n\n      acc_var_value = sess.run(acc_var.value)\n      self.assertAllEqual(acc_var_value, 5. 
* np.ones(shape=var_shape))\n\n  def test_accumulation(self):\n    var_shape = (2, 3)\n    acc_var = utils.AccumulatorVariable(\n        name='test_acc_var', shape=var_shape, dtype=tf.float32)\n    values = [\n        2. * tf.ones(shape=var_shape), 4. * tf.ones(shape=var_shape),\n        9. * tf.ones(shape=var_shape)\n    ]\n    acc_ops = []\n    accc_ops_after_reset = []\n    for value in values:\n      acc_ops.append(\n          acc_var.accumulate(value))\n\n    for value in values[:2]:\n      accc_ops_after_reset.append(\n          acc_var.accumulate(value))\n\n    init_op = tf.global_variables_initializer()\n    with self.test_session() as sess:\n      sess.run([init_op])\n\n      for acc_op in acc_ops:\n        sess.run([acc_op])\n\n      acc_var_value = sess.run(acc_var.read_value_and_reset())\n      self.assertAllEqual(acc_var_value, 5. * np.ones(shape=var_shape))\n\n      for acc_op in accc_ops_after_reset:\n        sess.run([acc_op])\n\n      acc_var_value = sess.run(acc_var.value)\n      self.assertAllEqual(acc_var_value, 3. * np.ones(shape=var_shape))\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "kfac/python/ops/__init__.py",
    "content": "\n"
  },
  {
    "path": "kfac/python/ops/curvature_matrix_vector_products.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Curvature matrix-vector multiplication.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport tensorflow.compat.v1 as tf\n\nfrom tensorflow.python.util import nest\nfrom kfac.python.ops import utils\n\n\nclass CurvatureMatrixVectorProductComputer(object):\n  \"\"\"Class for computing matrix-vector products for Fishers and GGNs.\n\n  In other words we compute M*v where M is the matrix, v is the vector, and\n  * refers to standard matrix/vector multiplication (not element-wise\n  multiplication).\n\n  The matrices are defined in terms of some differential quantity of the total\n  loss function with respect to a provided list of tensors (\"wrt_tensors\").\n  For example, the Fisher associated with a log-prob loss w.r.t. the\n  parameters.\n\n  The 'vecs' argument to each method are lists of tensors that must be the same\n  size as the corresponding ones from \"wrt_tensors\".  
They represent\n  the vector being multiplied.\n\n  \"factors\" of the matrix M are defined as matrices B such that B*B^T = M.\n  Methods that multiply by the factor B take a 'loss_inner_vecs' argument\n  instead of 'vecs', which must be a list of tensors with shapes given by the\n  corresponding XXX_inner_shapes property.\n\n  Note that matrix-vector products are not normalized by the batch size, nor\n  are any damping terms added to the results.  These things can be easily\n  applied externally, if desired.\n\n  See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf\n  and https://arxiv.org/abs/1412.1193 for more information about the\n  generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector\n  products.\n  \"\"\"\n\n  def __init__(self, layer_collection, wrt_tensors,\n               colocate_gradients_with_ops=True):\n    \"\"\"Create a CurvatureMatrixVectorProductComputer object.\n\n    Args:\n      layer_collection: A LayerCollection object where the desired loss\n        functions are registered (possibly with weighing factors).\n      wrt_tensors: A list of Tensors to compute the differential quantities\n        (defining the matrices) with respect to.  See class description for more\n        info.\n      colocate_gradients_with_ops: Whether we should request gradients be\n          colocated with their respective ops. 
(Default: True)\n    \"\"\"\n    self._layer_collection = layer_collection\n    self._wrt_tensors = wrt_tensors\n    self._colocate_gradients_with_ops = colocate_gradients_with_ops\n\n  @property\n  def _loss_colocation_ops(self):\n    return self._layer_collection.loss_colocation_ops\n\n  @property\n  def _losses(self):\n    return self._layer_collection.losses\n\n  @property\n  def _inputs_to_losses(self):\n    return list(loss.inputs for loss in self._losses)\n\n  @property\n  def _inputs_to_losses_flat(self):\n    return nest.flatten(self._inputs_to_losses)\n\n  @property\n  def _total_loss(self):\n    return self._layer_collection.total_loss()\n\n  def _get_loss_coeff(self, loss):\n    return self._layer_collection.loss_coeffs[loss]\n\n  # Jacobian multiplication functions:\n  def _multiply_jacobian(self, vecs):\n    \"\"\"Multiply vecs by the Jacobian of losses.\"\"\"\n    # We stop gradients at wrt_tensors to produce partial derivatives (which is\n    # what we want for Jacobians).\n    jacobian_vecs_flat = utils.fwd_gradients(\n        self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs,\n        stop_gradients=self._wrt_tensors,\n        colocate_gradients_with_ops=self._colocate_gradients_with_ops)\n    return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat)\n\n  def _multiply_jacobian_transpose(self, loss_vecs):\n    \"\"\"Multiply vecs by the transpose Jacobian of losses.\"\"\"\n    loss_vecs_flat = nest.flatten(loss_vecs)\n    # We stop gradients at wrt_tensors to produce partial derivatives (which is\n    # what we want for Jacobians).\n    return tf.gradients(\n        self._inputs_to_losses_flat,\n        self._wrt_tensors,\n        grad_ys=loss_vecs_flat,\n        stop_gradients=self._wrt_tensors,\n        colocate_gradients_with_ops=self._colocate_gradients_with_ops)\n\n  # Loss Fisher/GGN multiplication functions:\n  def _multiply_across_losses(self, mult_func, vecs, coeff_mode=\"regular\"):\n    products = []\n    for 
loss, vec in zip(self._losses, vecs):\n      with tf.colocate_with(self._loss_colocation_ops[loss]):\n        if coeff_mode == \"regular\":\n          multiplier = self._get_loss_coeff(loss)\n        elif coeff_mode == \"sqrt\":\n          multiplier = tf.sqrt(self._get_loss_coeff(loss))\n        val = mult_func(loss, vec)\n        products.append(tf.cast(multiplier, dtype=val.dtype) * val)\n    return tuple(products)\n\n  def _multiply_loss_fisher(self, loss_vecs):\n    \"\"\"Multiply loss_vecs by Fisher of total loss.\"\"\"\n    mult_func = lambda loss, vec: loss.multiply_fisher(vec)\n    return self._multiply_across_losses(mult_func, loss_vecs)\n\n  def _multiply_loss_fisher_factor(self, loss_inner_vecs):\n    \"\"\"Multiply loss_inner_vecs by factor of Fisher of total loss.\"\"\"\n    mult_func = lambda loss, vec: loss.multiply_fisher_factor(vec)\n    return self._multiply_across_losses(mult_func, loss_inner_vecs,\n                                        coeff_mode=\"sqrt\")\n\n  def _multiply_loss_fisher_factor_transpose(self, loss_vecs):\n    \"\"\"Multiply loss_vecs by transpose factor of Fisher of total loss.\"\"\"\n    mult_func = lambda loss, vec: loss.multiply_fisher_factor_transpose(vec)\n    return self._multiply_across_losses(mult_func, loss_vecs,\n                                        coeff_mode=\"sqrt\")\n\n  def _multiply_loss_ggn(self, loss_vecs):\n    \"\"\"Multiply loss_vecs by GGN of total loss.\"\"\"\n    mult_func = lambda loss, vec: loss.multiply_ggn(vec)\n    return self._multiply_across_losses(mult_func, loss_vecs)\n\n  def _multiply_loss_ggn_factor(self, loss_inner_vecs):\n    \"\"\"Multiply loss_inner_vecs by factor of GGN of total loss.\"\"\"\n    mult_func = lambda loss, vec: loss.multiply_ggn_factor(vec)\n    return self._multiply_across_losses(mult_func, loss_inner_vecs,\n                                        coeff_mode=\"sqrt\")\n\n  def _multiply_loss_ggn_factor_transpose(self, loss_vecs):\n    \"\"\"Multiply loss_vecs by 
transpose factor of GGN of total loss.\"\"\"\n    mult_func = lambda loss, vec: loss.multiply_ggn_factor_transpose(vec)\n    return self._multiply_across_losses(mult_func, loss_vecs,\n                                        coeff_mode=\"sqrt\")\n\n  # Matrix-vector product functions (users should directly call these):\n  def multiply_fisher(self, vecs):\n    \"\"\"Multiply vecs by Fisher of total loss.\"\"\"\n    jacobian_vecs = self._multiply_jacobian(vecs)\n    loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs)\n    return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs)\n\n  def multiply_fisher_factor_transpose(self, vecs):\n    \"\"\"Multiply vecs by transpose of factor of Fisher of total loss.\"\"\"\n    jacobian_vecs = self._multiply_jacobian(vecs)\n    return self._multiply_loss_fisher_factor_transpose(jacobian_vecs)\n\n  def multiply_fisher_factor(self, loss_inner_vecs):\n    \"\"\"Multiply loss_inner_vecs by factor of Fisher of total loss.\"\"\"\n    fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor(\n        loss_inner_vecs)\n    return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs)\n\n  def multiply_hessian(self, vecs):\n    \"\"\"Multiply vecs by Hessian of total loss.\"\"\"\n    return tf.gradients(\n        tf.gradients(\n            self._total_loss,\n            self._wrt_tensors,\n            colocate_gradients_with_ops=self._colocate_gradients_with_ops),\n        self._wrt_tensors,\n        grad_ys=vecs,\n        colocate_gradients_with_ops=self._colocate_gradients_with_ops)\n\n  def multiply_ggn(self, vecs):\n    \"\"\"Multiply vecs by generalized Gauss-Newton of total loss.\"\"\"\n    jacobian_vecs = self._multiply_jacobian(vecs)\n    loss_ggn_jacobian_vecs = self._multiply_loss_ggn(jacobian_vecs)\n    return self._multiply_jacobian_transpose(loss_ggn_jacobian_vecs)\n\n  def multiply_ggn_factor_transpose(self, vecs):\n    \"\"\"Multiply vecs by transpose of factor of GGN of total 
loss.\"\"\"\n    jacobian_vecs = self._multiply_jacobian(vecs)\n    return self._multiply_loss_ggn_factor_transpose(jacobian_vecs)\n\n  def multiply_ggn_factor(self, loss_inner_vecs):\n    \"\"\"Multiply loss_inner_vecs by factor of GGN of total loss.\"\"\"\n    ggn_factor_transpose_vecs = (\n        self._multiply_loss_ggn_factor(loss_inner_vecs))\n    return self._multiply_jacobian_transpose(ggn_factor_transpose_vecs)\n\n  # Shape properties for multiply_XXX_factor methods:\n  @property\n  def fisher_factor_inner_shapes(self):\n    \"\"\"Shapes required by multiply_fisher_factor.\"\"\"\n    return tuple(loss.fisher_factor_inner_shape for loss in self._losses)\n\n  @property\n  def fisher_factor_inner_static_shapes(self):\n    \"\"\"Shapes required by multiply_fisher_factor.\"\"\"\n    return tuple(loss.fisher_factor_inner_static_shape for loss in self._losses)\n\n  @property\n  def ggn_factor_inner_shapes(self):\n    \"\"\"Shapes required by multiply_generalized_gauss_newton_factor.\"\"\"\n    return tuple(loss.ggn_factor_inner_shape for loss in self._losses)\n\n  @property\n  def ggn_factor_inner_static_shapes(self):\n    \"\"\"Shapes required by multiply_generalized_gauss_newton_factor.\"\"\"\n    return tuple(loss.ggn_factor_inner_static_shape\n                 for loss in self._losses)\n"
  },
  {
    "path": "kfac/python/ops/estimator.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Defines the high-level Fisher estimator class.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\n\n# Dependency imports\nimport numpy as np\nimport six\nimport tensorflow.compat.v1 as tf\n\nfrom tensorflow.python.util import nest\nfrom kfac.python.ops import placement\nfrom kfac.python.ops import utils\n\n\n# The linter is confused.\n# pylint: disable=abstract-class-instantiated\ndef make_fisher_estimator(placement_strategy=None, **kwargs):\n  \"\"\"Creates Fisher estimator instances based on the placement strategy.\n\n  For example if the `placement_strategy` is 'round_robin' then\n  `FisherEstimatorRoundRobin` instance is returned.\n\n  Args:\n    placement_strategy: `string`, Strategy to be used for placing covariance\n      variables, covariance ops and inverse ops. Check\n      `placement.FisherEstimatorRoundRobin` for a concrete example.\n   **kwargs: Arguments to be passed into `FisherEstimator` class initializer.\n\n  Returns:\n    An instance of class which inherits from `FisherEstimator` and the mixin\n    which implements specific placement strategy. 
See,\n    `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and\n    `RoundRobinPlacementMixin`, as an example.\n\n  Raises:\n    ValueError: If the `placement_strategy` argument is not one of the\n    recognized options.\n  \"\"\"\n  if placement_strategy in [None, \"round_robin\"]:\n    return FisherEstimatorRoundRobin(**kwargs)\n  elif placement_strategy == \"replica_round_robin\":\n    return FisherEstimatorReplicaRoundRobin(**kwargs)\n  else:\n    raise ValueError(\n        \"Unimplemented vars and ops placement strategy : {}\".format(\n            placement_strategy))\n# pylint: enable=abstract-class-instantiated\n\n\n@six.add_metaclass(abc.ABCMeta)\nclass FisherEstimator(object):\n  \"\"\"Fisher estimator class supporting various approximations of the Fisher.\n\n  This is an abstract base class which does not implement a strategy for\n  placing covariance variables, covariance update ops and inverse update ops.\n  The placement strategies are implemented in `placement.py`. See\n  `FisherEstimatorRoundRobin` for example of a concrete subclass with\n  a round-robin placement strategy.\n  \"\"\"\n\n  def __init__(self,\n               variables,\n               cov_ema_decay,\n               damping,\n               layer_collection,\n               exps=(-1,),\n               estimation_mode=\"gradients\",\n               colocate_gradients_with_ops=True,\n               name=\"FisherEstimator\",\n               compute_cholesky=False,\n               compute_cholesky_inverse=False,\n               compute_params_stats=False,\n               batch_size=None):\n    \"\"\"Create a FisherEstimator object.\n\n    Args:\n      variables: A `list` of variables for which to estimate the Fisher. This\n        must match the variables registered in layer_collection (if it is not\n        None).\n      cov_ema_decay: The decay factor used when calculating the covariance\n        estimate moving averages.\n      damping: float or 0D Tensor. 
This quantity times the identity matrix is\n        (approximately) added to the matrix being estimated.\n      layer_collection: A LayerCollection object which holds the\n        Fisher blocks, Kronecker factors, and losses associated with the\n        graph.\n      exps: List of floats or ints. These represent the different matrix\n        powers of the approximate Fisher that the FisherEstimator will be able\n        to multiply vectors by. If the user asks for a matrix power other than\n        one of these (or 1, which is always supported), there will be a\n        failure. (Default: (-1,))\n      estimation_mode: The type of estimator to use for the Fishers/GGNs. Can be\n        'gradients', 'empirical', 'curvature_prop', 'curvature_prop_GGN',\n        'exact', or 'exact_GGN'. (Default: 'gradients'). 'gradients' is the\n        basic estimation approach from the original K-FAC paper.\n        'empirical' computes the 'empirical' Fisher information matrix (which\n        uses the data's distribution for the targets, as opposed to the true\n        Fisher which uses the model's distribution) and requires that each\n        registered loss have specified targets. 'curvature_propagation' is a\n        method which estimates the Fisher using self-products of random 1/-1\n        vectors times \"half-factors\" of the Fisher, as described here:\n        https://arxiv.org/abs/1206.6464 . 'exact' is the obvious\n        generalization of Curvature Propagation to compute the exact Fisher\n        (modulo any additional diagonal or Kronecker approximations) by\n        looping over one-hot vectors for each coordinate of the output\n        instead of using 1/-1 vectors.  It is more expensive to compute than\n        the other three options by a factor equal to the output dimension,\n        roughly speaking. 
Finally, 'curvature_prop_GGN' and 'exact_GGN' are\n        analogous to 'curvature_prop' and 'exact', but estimate the\n        Generalized Gauss-Newton matrix (GGN).\n      colocate_gradients_with_ops: Whether we should request gradients be\n        colocated with their respective ops. (Default: True)\n      name: A string. A name given to this estimator, which is added to the\n        variable scope when constructing variables and ops.\n        (Default: \"FisherEstimator\")\n      compute_cholesky: Bool. Whether or not the FisherEstimator will be\n        able to multiply vectors by the Cholesky factor.\n        (Default: False)\n      compute_cholesky_inverse: Bool. Whether or not the FisherEstimator\n        will be able to multiply vectors by the Cholesky factor inverse.\n        (Default: False)\n      compute_params_stats: Bool. If True, we compute the first order version\n        of the statistics computed to estimate the Fisher/GGN. These correspond\n        to the `variables` method in a one-to-one fashion.  They are available\n        via the `params_stats` property.  When estimation_mode is 'empirical',\n        this will correspond to the standard parameter gradient on the loss.\n        (Default: False)\n      batch_size: The size of the mini-batch. Only needed when\n        `compute_params_stats` is True. Note that when using data parallelism\n        where the model graph and optimizer are replicated across multiple\n        devices, this should be the per-replica batch size. An example of\n        this is sharded data on the TPU, where batch_size should be set to\n        the total batch size divided by the number of shards. 
(Default: None)\n\n    Raises:\n      ValueError: If no losses have been registered with layer_collection.\n    \"\"\"\n    self._variables = variables\n    self._cov_ema_decay = cov_ema_decay\n    self._damping = damping\n    self._estimation_mode = estimation_mode\n    self._layer_collection = layer_collection\n    self._gradient_fns = {\n        \"gradients\": self._get_grads_lists_gradients,\n        \"empirical\": self._get_grads_lists_empirical,\n        \"curvature_prop\": self._get_grads_lists_curvature_prop,\n        \"curvature_prop_GGN\": self._get_grads_lists_curvature_prop,\n        \"exact\": self._get_grads_lists_exact,\n        \"exact_GGN\": self._get_grads_lists_exact\n    }\n    self._mat_type_table = {\n        \"gradients\": \"Fisher\",\n        \"empirical\": \"Empirical_Fisher\",\n        \"curvature_prop\": \"Fisher\",\n        \"curvature_prop_GGN\": \"GGN\",\n        \"exact\": \"Fisher\",\n        \"exact_GGN\": \"GGN\",\n    }\n\n    self._colocate_gradients_with_ops = colocate_gradients_with_ops\n\n    self._exps = exps\n    self._compute_cholesky = compute_cholesky\n    self._compute_cholesky_inverse = compute_cholesky_inverse\n\n    self._name = name\n\n    self._compute_params_stats = compute_params_stats\n    self._batch_size = batch_size\n\n    if compute_params_stats and batch_size is None:\n      raise ValueError(\"Batch size needs to be passed in when \"\n                       \"compute_params_stats is True.\")\n\n    self._finalized = False\n\n  @property\n  def variables(self):\n    return self._variables\n\n  @property\n  def damping(self):\n    return self._damping\n\n  @property\n  def blocks(self):\n    \"\"\"All registered FisherBlocks.\"\"\"\n    return self.layers.get_blocks()\n\n  @property\n  def factors(self):\n    \"\"\"All registered FisherFactors.\"\"\"\n    return self.layers.get_factors()\n\n  @property\n  def name(self):\n    return self._name\n\n  @property\n  def layers(self):\n    return 
self._layer_collection\n\n  @property\n  def mat_type(self):\n    return self._mat_type_table[self._estimation_mode]\n\n  @property\n  def params_stats(self):\n    return self._params_stats\n\n  @abc.abstractmethod\n  def _place_and_compute_transformation_thunks(self, thunks, params_list):\n    \"\"\"Computes transformation thunks with device placement.\n\n    Device placement will be determined by the strategy asked for when this\n    estimator was constructed.\n\n    Args:\n      thunks: A list of thunks to run. Must be in one to one correspondence\n        with the `blocks` property.\n      params_list: A list of the corresponding parameters. Must be in one to one\n        correspondence with the `blocks` property.\n\n    Returns:\n      A list (in the same order) of the returned results of the thunks,\n      with possible device placement applied.\n    \"\"\"\n    pass\n\n  def _compute_transformation(self, vecs_and_vars, transform):\n    \"\"\"Computes a block-wise transformation of a list of vectors.\n\n    Args:\n      vecs_and_vars: List of (vector, variable) pairs.\n      transform: A function of the form f(fb, vec), that\n          returns the transformed vector, where vec is the vector\n          to transform and fb is its corresponding block.\n\n    Returns:\n      A list of (transformed vector, var) pairs in the same order as\n      vecs_and_vars.\n    \"\"\"\n\n    vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)\n\n    def make_thunk(fb, params):\n      return lambda: transform(fb, vecs[params])\n\n    thunks = tuple(make_thunk(fb, params)\n                   for params, fb in self.layers.fisher_blocks.items())\n\n    params_list = tuple(params\n                        for params, _ in self.layers.fisher_blocks.items())\n\n    results = self._place_and_compute_transformation_thunks(thunks, params_list)\n\n    trans_vecs = utils.SequenceDict()\n    for params, result in zip(self.layers.fisher_blocks, results):\n      
trans_vecs[params] = result\n\n    return [(trans_vecs[var], var) for _, var in vecs_and_vars]\n\n  def multiply_inverse(self, vecs_and_vars):\n    \"\"\"Multiplies the vecs by the corresponding (damped) inverses of the blocks.\n\n    Args:\n      vecs_and_vars: List of (vector, variable) pairs.\n\n    Returns:\n      A list of (transformed vector, var) pairs in the same order as\n      vecs_and_vars.\n    \"\"\"\n    return self.multiply_matpower(-1, vecs_and_vars)\n\n  def multiply(self, vecs_and_vars):\n    \"\"\"Multiplies the vectors by the corresponding (damped) blocks.\n\n    Args:\n      vecs_and_vars: List of (vector, variable) pairs.\n\n    Returns:\n      A list of (transformed vector, var) pairs in the same order as\n      vecs_and_vars.\n    \"\"\"\n    return self.multiply_matpower(1, vecs_and_vars)\n\n  def multiply_matpower(self, exp, vecs_and_vars):\n    \"\"\"Multiplies the vecs by the corresponding matrix powers of the blocks.\n\n    Args:\n      exp: A float representing the power to raise the blocks by before\n        multiplying it by the vector.\n      vecs_and_vars: List of (vector, variable) pairs.\n\n    Returns:\n      A list of (transformed vector, var) pairs in the same order as\n      vecs_and_vars.\n    \"\"\"\n    fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)\n    return self._compute_transformation(vecs_and_vars, fcn)\n\n  def multiply_cholesky(self, vecs_and_vars, transpose=False):\n    \"\"\"Multiplies the vecs by the corresponding Cholesky factors.\n\n    Args:\n      vecs_and_vars: List of (vector, variable) pairs.\n      transpose: Bool. If true the Cholesky factors are transposed before\n        multiplying the vecs. 
(Default: False)\n\n    Returns:\n      A list of (transformed vector, var) pairs in the same order as\n      vecs_and_vars.\n    \"\"\"\n\n    fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)\n    return self._compute_transformation(vecs_and_vars, fcn)\n\n  def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):\n    \"\"\"Mults the vecs by the inverses of the corresponding Cholesky factors.\n\n      Note: if you are using Cholesky inverse multiplication to sample from\n      a matrix-variate Gaussian you will want to multiply by the transpose.\n      Let L be the Cholesky factor of F and observe that\n\n        L^-T * L^-1 = (L * L^T)^-1 = F^-1 .\n\n      Thus we want to multiply by L^-T in order to sample from Gaussian with\n      covariance F^-1.\n\n    Args:\n      vecs_and_vars: List of (vector, variable) pairs.\n      transpose: Bool. If true the Cholesky factor inverses are transposed\n        before multiplying the vecs. (Default: False)\n\n    Returns:\n      A list of (transformed vector, var) pairs in the same order as\n      vecs_and_vars.\n    \"\"\"\n\n    fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)\n    return self._compute_transformation(vecs_and_vars, fcn)\n\n  def _instantiate_factors(self):\n    \"\"\"Instantiates FisherFactors' variables.\n\n    Raises:\n      ValueError: If estimation_mode was improperly specified at construction.\n    \"\"\"\n    blocks = self.blocks\n    tensors_to_compute_grads = [\n        block.tensors_to_compute_grads() for block in blocks\n    ]\n\n    if self._compute_params_stats:\n      tensors_to_compute_grads = tensors_to_compute_grads + self.variables\n\n    try:\n      grads_lists = self._gradient_fns[self._estimation_mode](\n          tensors_to_compute_grads)\n    except KeyError:\n      raise ValueError(\"Unrecognized value {} for estimation_mode.\".format(\n          self._estimation_mode))\n\n    if any(grad is None for grad in 
nest.flatten(grads_lists)):\n      tensors_flat = nest.flatten(tensors_to_compute_grads)\n      grads_flat = nest.flatten(grads_lists)\n      bad_tensors = tuple(\n          tensor for tensor, grad in zip(tensors_flat, grads_flat)\n          if grad is None)\n      bad_string = \"\"\n      for tensor in bad_tensors:\n        bad_string += \"\\t{}\\n\".format(tensor)\n      raise ValueError(\"It looks like you registered one of more tensors that \"\n                       \"the registered loss/losses don't depend on. (These \"\n                       \"returned None from tf.gradients.) The tensors were:\"\n                       \"\\n\\n\" + bad_string)\n\n    if self._compute_params_stats:\n      idx = len(blocks)\n      params_stats_unnorm = tuple(tf.add_n(grad_list)\n                                  for grad_list in grads_lists[idx:])\n\n      scalar = 1. / tf.cast(self._batch_size,\n                            dtype=params_stats_unnorm[0].dtype)\n      params_stats = utils.sprod(scalar, params_stats_unnorm)\n\n      # batch_size should be the per-replica batch size and thus we do a\n      # cross-replica mean instead of a sum here\n      self._params_stats = tuple(utils.all_average(tensor)\n                                 for tensor in params_stats)\n\n      grads_lists = grads_lists[:idx]\n\n    for grads_list, block in zip(grads_lists, blocks):\n      block.instantiate_factors(grads_list, self.damping)\n\n  def _register_matrix_functions(self):\n    for block in self.blocks:\n      for exp in self._exps:\n        block.register_matpower(exp)\n      if self._compute_cholesky:\n        block.register_cholesky()\n      if self._compute_cholesky_inverse:\n        block.register_cholesky_inverse()\n\n  def _finalize(self):\n    if not self._finalized:\n      self.layers.finalize()\n      self.layers.check_registration(self.variables)\n      self._instantiate_factors()\n      self._register_matrix_functions()\n\n    self._finalized = True\n\n  def 
_check_batch_sizes(self, factor):\n    \"\"\"Checks that the batch size(s) for a factor matches the reference value.\"\"\"\n\n    # Should make these messages use quote characters instead of parentheses\n    # when the bug with quote character rendering in assertion messages is\n    # fixed. See b/129476712\n    if self._batch_size is None:\n      batch_size = self.factors[0].batch_size()\n      string = (\"Batch size {} for factor (\" + factor.name + \") of type \"\n                + utils.cls_name(factor) + \" did not match value {} used by \"\n                \"factor (\" + self.factors[0].name + \") of type \"\n                + utils.cls_name(self.factors[0]) + \".\")\n    else:\n      batch_size = self._batch_size\n      string = (\"Batch size {} for factor (\" + factor.name + \") of type \"\n                + utils.cls_name(factor) + \" did not match value {} which was \"\n                \"passed to optimizer/estimator.\")\n\n    factor_batch_size = factor.batch_size()\n\n    if isinstance(batch_size, int) and isinstance(factor_batch_size, int):\n      if factor_batch_size != batch_size:\n        raise ValueError(string.format(factor_batch_size, batch_size))\n      return factor.check_partial_batch_sizes()\n\n    else:\n      message = string.format(\"(x)\", \"(y)\")\n      with tf.control_dependencies([factor.check_partial_batch_sizes()]):\n        return tf.assert_equal(factor_batch_size, batch_size, message=message)\n\n  def _create_ops_and_vars_thunks(self, scope=None):\n    \"\"\"Create thunks that make the ops and vars on demand.\n\n    This function returns 4 lists of thunks: cov_variable_thunks,\n    cov_update_thunks, inv_variable_thunks, and inv_update_thunks.\n\n    The length of each list is the number of factors and the i-th element of\n    each list corresponds to the i-th factor (given by the \"factors\" property).\n\n    Note that the execution of these thunks must happen in a certain\n    partial order.  
The i-th element of cov_variable_thunks must execute\n    before the i-th element of cov_update_thunks (and also the i-th element\n    of inv_update_thunks).  Similarly, the i-th element of inv_variable_thunks\n    must execute before the i-th element of inv_update_thunks.\n\n    TL;DR (oversimplified): Execute the thunks according to the order that\n    they are returned.\n\n    Args:\n      scope: A string or None.  If None it will be set to the name of this\n        estimator (given by the name property). All thunks will execute inside\n        of a variable scope of the given name. (Default: None)\n    Returns:\n      cov_variable_thunks: A list of thunks that make the cov variables.\n      cov_update_thunks: A list of thunks that make the cov update ops.\n      inv_variable_thunks: A list of thunks that make the inv variables.\n      inv_update_thunks: A list of thunks that make the inv update ops.\n    \"\"\"\n\n    self._finalize()\n\n    scope = self.name if scope is None else scope\n\n    cov_variable_thunks = [\n        self._create_cov_variable_thunk(factor, scope)\n        for factor in self.factors\n    ]\n    cov_update_thunks = [\n        self._create_cov_update_thunk(factor, scope) for factor in self.factors\n    ]\n    inv_variable_thunks = [\n        self._create_inv_variable_thunk(factor, scope)\n        for factor in self.factors\n    ]\n    inv_update_thunks = [\n        self._create_inv_update_thunk(factor, scope) for factor in self.factors\n    ]\n\n    return (cov_variable_thunks, cov_update_thunks,\n            inv_variable_thunks, inv_update_thunks)\n\n  @abc.abstractmethod\n  def create_ops_and_vars_thunks(self, scope=None):\n    \"\"\"Create thunks that make the ops and vars on demand with device placement.\n\n    This function returns 4 lists of thunks: cov_variable_thunks,\n    cov_update_thunks, inv_variable_thunks, and inv_update_thunks.\n\n    The length of each list is the number of factors and the i-th element of\n    each list 
corresponds to the i-th factor (given by the \"factors\" property).\n\n    Note that the execution of these thunks must happen in a certain\n    partial order.  The i-th element of cov_variable_thunks must execute\n    before the i-th element of cov_update_thunks (and also the i-th element\n    of inv_update_thunks).  Similarly, the i-th element of inv_variable_thunks\n    must execute before the i-th element of inv_update_thunks.\n\n    TL;DR (oversimplified): Execute the thunks according to the order that\n    they are returned.\n\n    Device placement will be determined by the strategy asked for when this\n    estimator was constructed.\n\n    Args:\n      scope: A string or None.  If None it will be set to the name of this\n        estimator (given by the name property). All thunks will execute inside\n        of a variable scope of the given name. (Default: None)\n    Returns:\n      cov_variable_thunks: A list of thunks that make the cov variables.\n      cov_update_thunks: A list of thunks that make the cov update ops.\n      inv_variable_thunks: A list of thunks that make the inv variables.\n      inv_update_thunks: A list of thunks that make the inv update ops.\n    \"\"\"\n    pass\n\n  def make_vars_and_create_op_thunks(self, scope=None):\n    \"\"\"Make vars and create op thunks with device placement.\n\n    Similar to create_ops_and_vars_thunks but actually makes the variables\n    instead of returning thunks that make them.\n\n    Device placement will be determined by the strategy asked for when this\n    estimator was constructed.\n\n    Args:\n      scope: A string or None.  If None it will be set to the name of this\n        estimator (given by the name property). All variables will be created,\n        and all thunks will execute, inside of a variable scope of the given\n        name. (Default: None)\n\n    Returns:\n      cov_update_thunks: List of cov update thunks. 
Corresponds one-to-one with\n        the list of factors given by the \"factors\" property.\n      inv_update_thunks: List of inv update thunks. Corresponds one-to-one with\n        the list of factors given by the \"factors\" property.\n    \"\"\"\n    (cov_variable_thunks, cov_update_thunks, inv_variable_thunks,\n     inv_update_thunks) = self.create_ops_and_vars_thunks(scope=scope)\n\n    for thunk in cov_variable_thunks:\n      thunk()\n\n    for thunk in inv_variable_thunks:\n      thunk()\n\n    return cov_update_thunks, inv_update_thunks\n\n  def get_cov_vars(self):\n    \"\"\"Returns all covariance variables associated with each Fisher factor.\n\n    Note the returned list also includes additional factor specific covariance\n    variables.\n\n    Returns: List of list. The number of inner lists is equal to number of\n      factors. And each inner list contains all covariance\n      variables for that factor.\n    \"\"\"\n    return tuple(factor.get_cov_vars() for factor in self.factors)\n\n  def get_inv_vars(self):\n    \"\"\"Returns all covariance variables associated with each Fisher factor.\n\n    Note the returned list also includes additional factor specific covariance\n    variables.\n\n    Returns: List of list. The number of inner lists is equal to number of\n      factors. 
And each inner list contains all inverse computation related\n      variables for that factor.\n    \"\"\"\n    return tuple(factor.get_inv_vars() for factor in self.factors)\n\n  def _create_cov_variable_thunk(self, factor, scope):\n    \"\"\"Constructs a covariance variable thunk for a single FisherFactor.\"\"\"\n\n    def thunk():\n      with tf.variable_scope(scope):\n        return factor.instantiate_cov_variables()\n\n    return thunk\n\n  def _create_cov_update_thunk(self, factor, scope):\n    \"\"\"Constructs a covariance update thunk for a single FisherFactor.\"\"\"\n\n    def thunk(should_decay=True):\n      if isinstance(should_decay, bool):\n        ema_decay = self._cov_ema_decay if should_decay else 1.0\n      else:\n        ema_decay = tf.cond(should_decay,\n                            lambda: self._cov_ema_decay,\n                            lambda: 1.0)\n      ema_weight = 1.0\n\n      with tf.variable_scope(scope):\n        with tf.control_dependencies([self._check_batch_sizes(factor)]):\n          return factor.make_covariance_update_op(ema_decay, ema_weight)\n\n    return thunk\n\n  def _create_inv_variable_thunk(self, factor, scope):\n    \"\"\"Constructs a inverse variable thunk for a single FisherFactor.\"\"\"\n\n    def thunk():\n      with tf.variable_scope(scope):\n        return factor.instantiate_inv_variables()\n\n    return thunk\n\n  def _create_inv_update_thunk(self, factor, scope):\n    \"\"\"Constructs an inverse update thunk for a single FisherFactor.\"\"\"\n\n    def thunk():\n      with tf.variable_scope(scope):\n        return tf.group(factor.make_inverse_update_ops())\n\n    return thunk\n\n  def _get_grads_lists_gradients(self, tensors):\n    # Passing in a list of loss values is better than passing in the sum as\n    # the latter creates unnecessary ops on the default device\n    grads_flat = tf.gradients(\n        self.layers.eval_losses(target_mode=\"sample\", coeff_mode=\"sqrt\"),\n        nest.flatten(tensors),\n        
colocate_gradients_with_ops=self._colocate_gradients_with_ops)\n    grads_all = nest.pack_sequence_as(tensors, grads_flat)\n    return tuple((grad,) for grad in grads_all)\n\n  def _get_grads_lists_empirical(self, tensors):\n    # Passing in a list of loss values is better than passing in the sum as\n    # the latter creates unnessesary ops on the default device\n    grads_flat = tf.gradients(\n        self.layers.eval_losses(target_mode=\"data\", coeff_mode=\"regular\"),\n        nest.flatten(tensors),\n        colocate_gradients_with_ops=self._colocate_gradients_with_ops)\n    grads_all = nest.pack_sequence_as(tensors, grads_flat)\n    return tuple((grad,) for grad in grads_all)\n\n  def _get_transformed_random_signs(self):\n    if self.mat_type == \"Fisher\":\n      mult_func = lambda loss, index: loss.multiply_fisher_factor(index)\n      inner_shape_func = lambda loss: loss.fisher_factor_inner_shape\n    elif self.mat_type == \"GGN\":\n      mult_func = lambda loss, index: loss.multiply_ggn_factor(index)\n      inner_shape_func = lambda loss: loss.ggn_factor_inner_shape\n\n    transformed_random_signs = []\n    for loss in self.layers.losses:\n      with tf.colocate_with(self.layers.loss_colocation_ops[loss]):\n        value = mult_func(loss,\n                          utils.generate_random_signs(inner_shape_func(loss),\n                                                      dtype=loss.dtype))\n        coeff = tf.cast(self.layers.loss_coeffs[loss], dtype=value.dtype)\n        transformed_random_signs.append(tf.sqrt(coeff) * value)\n    return transformed_random_signs\n\n  def _get_grads_lists_curvature_prop(self, tensors):\n    loss_inputs = list(loss.inputs for loss in self.layers.losses)\n    transformed_random_signs = self._get_transformed_random_signs()\n    grads_flat = tf.gradients(\n        nest.flatten(loss_inputs),\n        nest.flatten(tensors),\n        grad_ys=nest.flatten(transformed_random_signs),\n        
colocate_gradients_with_ops=self._colocate_gradients_with_ops)\n    grads_all = nest.pack_sequence_as(tensors, grads_flat)\n    return tuple((grad,) for grad in grads_all)\n\n  def _get_grads_lists_exact(self, tensors):\n    if self.mat_type == \"Fisher\":\n      # pylint: disable=g-long-lambda\n      mult_func = (lambda loss, index:\n                   loss.multiply_fisher_factor_replicated_one_hot(index))\n      inner_shape_func = lambda loss: loss.fisher_factor_inner_static_shape\n    elif self.mat_type == \"GGN\":\n      # pylint: disable=g-long-lambda\n      mult_func = (lambda loss, index:\n                   loss.multiply_ggn_factor_replicated_one_hot(index))\n      inner_shape_func = lambda loss: loss.fisher_ggn_inner_static_shape\n\n    # Loop over all coordinates of all losses.\n    grads_all = []\n    for loss in self.layers.losses:\n      with tf.colocate_with(self.layers.loss_colocation_ops[loss]):\n        for index in np.ndindex(*inner_shape_func(loss)[1:]):\n          value = mult_func(loss, index)\n          coeff = tf.cast(self.layers.loss_coeffs[loss], dtype=value.dtype)\n          transformed_one_hot = tf.sqrt(coeff) * value\n          grads_flat = tf.gradients(\n              loss.inputs,\n              nest.flatten(tensors),\n              grad_ys=transformed_one_hot,\n              colocate_gradients_with_ops=self._colocate_gradients_with_ops)\n          grads_all.append(nest.pack_sequence_as(tensors, grads_flat))\n    return tuple(zip(*grads_all))\n\n\nclass FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin,\n                                FisherEstimator):\n  \"\"\"FisherEstimator which provides round robin device placement strategy.\"\"\"\n  pass\n\n\nclass FisherEstimatorReplicaRoundRobin(\n    placement.ReplicaRoundRobinPlacementMixin,\n    FisherEstimator):\n  \"\"\"FisherEstimator which provides round robin replica placement strategy.\"\"\"\n  pass\n"
  },
  {
    "path": "kfac/python/ops/fisher_blocks.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"FisherBlock definitions.\n\nThis library contains classes for estimating blocks in a model's Fisher\nInformation matrix. Suppose one has a model that parameterizes a posterior\ndistribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its\nFisher Information matrix is given by,\n\n  F(params) = E[ v(x, y, params) v(x, y, params)^T ]\n\nwhere,\n\n  v(x, y, params) = (d / d params) log p(y | x, params)\n\nand the expectation is taken with respect to the data's distribution for 'x' and\nthe model's posterior distribution for 'y',\n\n  x ~ p(x)\n  y ~ p(y | x, params)\n\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\n\n# Dependency imports\nimport six\nimport tensorflow.compat.v1 as tf\n\nfrom tensorflow.python.util import nest\nfrom kfac.python.ops import fisher_factors\nfrom kfac.python.ops import utils\n\n# For blocks corresponding to convolutional layers, or any type of block where\n# the parameters can be thought of as being replicated in time or space,\n# we want to adjust the scale of the damping by\n#   damping /= num_replications ** NORMALIZE_DAMPING_POWER\nNORMALIZE_DAMPING_POWER = 1.0\n\n# Methods for adjusting damping for FisherBlocks. 
See\n# compute_pi_adjusted_damping() for details.\nPI_OFF_NAME = \"off\"\nPI_TRACENORM_NAME = \"tracenorm\"\nPI_TYPE = PI_TRACENORM_NAME\n\n\ndef set_global_constants(normalize_damping_power=None, pi_type=None):\n  \"\"\"Sets various global constants used by the classes in this module.\"\"\"\n  global NORMALIZE_DAMPING_POWER\n  global PI_TYPE\n\n  if normalize_damping_power is not None:\n    NORMALIZE_DAMPING_POWER = normalize_damping_power\n\n  if pi_type is not None:\n    PI_TYPE = pi_type\n\n\ndef normalize_damping(damping, num_replications):\n  \"\"\"Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER.\"\"\"\n  if NORMALIZE_DAMPING_POWER:\n    return damping / (num_replications ** NORMALIZE_DAMPING_POWER)\n  return damping\n\n\ndef compute_pi_tracenorm(left_cov, right_cov):\n  \"\"\"Computes the scalar constant pi for Tikhonov regularization/damping.\n\n  pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )\n  See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.\n\n  Args:\n    left_cov: A LinearOperator object. The left Kronecker factor \"covariance\".\n    right_cov: A LinearOperator object. The right Kronecker factor \"covariance\".\n\n  Returns:\n    The computed scalar constant pi for these Kronecker Factors (as a Tensor).\n  \"\"\"\n  # Instead of dividing by the dim of the norm, we multiply by the dim of the\n  # other norm. This works out the same in the ratio.\n  left_norm = left_cov.trace() * int(right_cov.domain_dimension)\n  right_norm = right_cov.trace() * int(left_cov.domain_dimension)\n\n  def normal_case():\n    assert_positive = tf.assert_positive(\n        right_norm,\n        message=\"PI computation, trace of right cov matrix should be positive. \"\n        \"Note that most likely cause of this error is that the optimizer \"\n        \"diverged (e.g. 
due to bad hyperparameters).\")\n    with tf.control_dependencies([assert_positive]):\n      return tf.sqrt(left_norm / right_norm)\n\n  def zero_case():\n    return tf.constant(1.0, dtype=left_norm.dtype)\n\n  return tf.cond(tf.equal(left_norm * right_norm, 0.0), zero_case, normal_case)\n\n\ndef compute_pi_adjusted_damping(left_cov, right_cov, damping):\n\n  if PI_TYPE == PI_TRACENORM_NAME:\n    pi = compute_pi_tracenorm(left_cov, right_cov)\n    damping = tf.cast(damping, dtype=pi.dtype)\n    return (damping * pi, damping / pi)\n\n  elif PI_TYPE == PI_OFF_NAME:\n    return (damping, damping)\n\n\nclass PackagedFunc(object):\n  \"\"\"A Python thunk with a stable ID.\n\n  Enables stable names for lambdas.\n  \"\"\"\n\n  def __init__(self, func, func_id):\n    \"\"\"Initializes PackagedFunc.\n\n    Args:\n      func: a zero-arg Python function.\n      func_id: a hashable, function that produces a hashable, or a list/tuple\n        thereof.\n    \"\"\"\n    self._func = func\n    func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,)\n    self._func_id = func_id\n\n  def __call__(self):\n    return self._func()\n\n  @property\n  def func_id(self):\n    \"\"\"A hashable identifier for this function.\"\"\"\n    return tuple(elt() if callable(elt) else elt for elt in self._func_id)\n\n\ndef _package_func(func, func_id):\n  return PackagedFunc(func, func_id)\n\n\n@six.add_metaclass(abc.ABCMeta)\nclass FisherBlock(object):\n  \"\"\"Abstract base class for objects modeling approximate Fisher matrix blocks.\n\n  Subclasses must implement register_matpower, multiply_matpower,\n  instantiate_factors, tensors_to_compute_grads, and num_registered_towers\n  methods.\n  \"\"\"\n\n  def __init__(self, layer_collection):\n    self._layer_collection = layer_collection\n\n  @abc.abstractmethod\n  def instantiate_factors(self, grads_list, damping):\n    \"\"\"Creates and registers the component factors of this Fisher block.\n\n    Args:\n      grads_list: A list 
gradients (each a Tensor or tuple of Tensors) with\n          respect to the tensors returned by tensors_to_compute_grads() that\n          are to be used to estimate the block.\n      damping: The damping factor (float or Tensor).\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def register_matpower(self, exp):\n    \"\"\"Registers a matrix power to be computed by the block.\n\n    Args:\n      exp: A float representing the power to raise the block by.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def register_cholesky(self):\n    \"\"\"Registers a Cholesky factor to be computed by the block.\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def register_cholesky_inverse(self):\n    \"\"\"Registers an inverse Cholesky factor to be computed by the block.\"\"\"\n    pass\n\n  def register_inverse(self):\n    \"\"\"Registers a matrix inverse to be computed by the block.\"\"\"\n    self.register_matpower(-1)\n\n  @abc.abstractmethod\n  def multiply_matpower(self, vector, exp):\n    \"\"\"Multiplies the vector by the (damped) matrix-power of the block.\n\n    Args:\n      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.\n      exp: A float representing the power to raise the block by before\n        multiplying it by the vector.\n\n    Returns:\n      The vector left-multiplied by the (damped) matrix-power of the block.\n    \"\"\"\n    pass\n\n  def multiply_inverse(self, vector):\n    \"\"\"Multiplies the vector by the (damped) inverse of the block.\n\n    Args:\n      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.\n\n    Returns:\n      The vector left-multiplied by the (damped) inverse of the block.\n    \"\"\"\n    return self.multiply_matpower(vector, -1)\n\n  def multiply(self, vector):\n    \"\"\"Multiplies the vector by the (damped) block.\n\n    Args:\n      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.\n\n    Returns:\n      The vector left-multiplied by the (damped) block.\n    \"\"\"\n    return 
self.multiply_matpower(vector, 1)\n\n  @abc.abstractmethod\n  def multiply_cholesky(self, vector, transpose=False):\n    \"\"\"Multiplies the vector by the (damped) Cholesky-factor of the block.\n\n    Args:\n      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.\n      transpose: Bool. If true the Cholesky factor is transposed before\n        multiplying the vector. (Default: False)\n\n    Returns:\n      The vector left-multiplied by the (damped) Cholesky-factor of the block.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def multiply_cholesky_inverse(self, vector, transpose=False):\n    \"\"\"Multiplies vector by the (damped) inverse Cholesky-factor of the block.\n\n    Args:\n      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.\n      transpose: Bool. If true the Cholesky factor inverse is transposed\n        before multiplying the vector. (Default: False)\n    Returns:\n      Vector left-multiplied by (damped) inverse Cholesky-factor of the block.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def tensors_to_compute_grads(self):\n    \"\"\"Returns the Tensor(s) with respect to which this FisherBlock needs grads.\n    \"\"\"\n    pass\n\n  @abc.abstractproperty\n  def num_registered_towers(self):\n    \"\"\"Number of towers registered for this FisherBlock.\n\n    Typically equal to the number of towers in a multi-tower setup.\n    \"\"\"\n    pass\n\n\n@six.add_metaclass(abc.ABCMeta)\nclass FullFB(FisherBlock):\n  \"\"\"Base class for blocks that use full matrix representations (no approx).\"\"\"\n\n  def register_matpower(self, exp):\n    self._factor.register_matpower(exp, self._damping_func)\n\n  def register_cholesky(self):\n    self._factor.register_cholesky(self._damping_func)\n\n  def register_cholesky_inverse(self):\n    self._factor.register_cholesky_inverse(self._damping_func)\n\n  def _multiply_matrix(self, matrix, vector, transpose=False):\n    vector_flat = utils.tensors_to_column(vector)\n    out_flat 
= matrix.matmul(vector_flat, adjoint=transpose)\n    return utils.column_to_tensors(vector, out_flat)\n\n  def multiply_matpower(self, vector, exp):\n    matrix = self._factor.get_matpower(exp, self._damping_func)\n    return self._multiply_matrix(matrix, vector)\n\n  def multiply_cholesky(self, vector, transpose=False):\n    matrix = self._factor.get_cholesky(self._damping_func)\n    return self._multiply_matrix(matrix, vector, transpose=transpose)\n\n  def multiply_cholesky_inverse(self, vector, transpose=False):\n    matrix = self._factor.get_cholesky_inverse(self._damping_func)\n    return self._multiply_matrix(matrix, vector, transpose=transpose)\n\n  def full_fisher_block(self):\n    \"\"\"Explicitly constructs the full Fisher block.\"\"\"\n    return self._factor.get_cov_as_linear_operator().to_dense()\n\n\nclass NaiveFullFB(FullFB):\n  \"\"\"FisherBlock using a full matrix estimate (no approximations).\n\n  NaiveFullFB uses a full matrix estimate (no approximations), and should only\n  ever be used for very low dimensional parameters.\n\n  Note that this uses the naive \"square the sum estimator\", and so is applicable\n  to any type of parameter in principle, but has very high variance.\n  \"\"\"\n\n  def __init__(self, layer_collection, params):\n    \"\"\"Creates a NaiveFullFB block.\n\n    Args:\n      layer_collection: The LayerCollection object which owns this block.\n      params: The parameters of this layer (Tensor or tuple of Tensors).\n    \"\"\"\n    self._batch_sizes = []\n    self._params = params\n\n    super(NaiveFullFB, self).__init__(layer_collection)\n\n  def instantiate_factors(self, grads_list, damping):\n    self._damping_func = _package_func(lambda: damping, (damping,))\n\n    self._factor = self._layer_collection.make_or_get_factor(\n        fisher_factors.NaiveFullFactor, (grads_list, self._batch_size))\n\n  def tensors_to_compute_grads(self):\n    return self._params\n\n  def register_additional_tower(self, batch_size):\n    
\"\"\"Register an additional tower.\n\n    Args:\n      batch_size: The batch size, used in the covariance estimator.\n    \"\"\"\n    self._batch_sizes.append(batch_size)\n\n  @property\n  def num_registered_towers(self):\n    return len(self._batch_sizes)\n\n  @property\n  def _batch_size(self):\n    return tf.reduce_sum(self._batch_sizes)\n\n\n@six.add_metaclass(abc.ABCMeta)\nclass DiagonalFB(FisherBlock):\n  \"\"\"A base class for FisherBlocks that use diagonal approximations.\"\"\"\n\n  def register_matpower(self, exp):\n    # Not needed for this.  Matrix powers are computed on demand in the\n    # diagonal case\n    pass\n\n  def register_cholesky(self):\n    # Not needed for this.  Cholesky's are computed on demand in the\n    # diagonal case\n    pass\n\n  def register_cholesky_inverse(self):\n    # Not needed for this.  Cholesky inverses's are computed on demand in the\n    # diagonal case\n    pass\n\n  def _multiply_matrix(self, matrix, vector):\n    vector_flat = utils.tensors_to_column(vector)\n    out_flat = matrix.matmul(vector_flat)\n    return utils.column_to_tensors(vector, out_flat)\n\n  def multiply_matpower(self, vector, exp):\n    matrix = self._factor.get_matpower(exp, self._damping_func)\n    return self._multiply_matrix(matrix, vector)\n\n  def multiply_cholesky(self, vector, transpose=False):\n    matrix = self._factor.get_cholesky(self._damping_func)\n    return self._multiply_matrix(matrix, vector)\n\n  def multiply_cholesky_inverse(self, vector, transpose=False):\n    matrix = self._factor.get_cholesky_inverse(self._damping_func)\n    return self._multiply_matrix(matrix, vector)\n\n  def full_fisher_block(self):\n    return self._factor.get_cov_as_linear_operator().to_dense()\n\n\nclass NaiveDiagonalFB(DiagonalFB):\n  \"\"\"FisherBlock using a diagonal matrix approximation.\n\n  This type of approximation is generically applicable but quite primitive.\n\n  Note that this uses the naive \"square the sum estimator\", and so is 
applicable\n  to any type of parameter in principle, but has very high variance.\n  \"\"\"\n\n  def __init__(self, layer_collection, params):\n    \"\"\"Creates a NaiveDiagonalFB block.\n\n    Args:\n      layer_collection: The LayerCollection object which owns this block.\n      params: The parameters of this layer (must be a single Tensor).\n    \"\"\"\n    self._params = params\n    self._batch_sizes = []\n\n    super(NaiveDiagonalFB, self).__init__(layer_collection)\n\n  def instantiate_factors(self, grads_list, damping):\n    self._damping_func = _package_func(lambda: damping, (damping,))\n\n    self._factor = self._layer_collection.make_or_get_factor(\n        fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))\n\n  def tensors_to_compute_grads(self):\n    return self._params\n\n  def register_additional_tower(self, batch_size):\n    \"\"\"Register an additional tower.\n\n    Args:\n      batch_size: The batch size, used in the covariance estimator.\n    \"\"\"\n    self._batch_sizes.append(batch_size)\n\n  @property\n  def num_registered_towers(self):\n    return len(self._batch_sizes)\n\n  @property\n  def _batch_size(self):\n    return tf.reduce_sum(self._batch_sizes)\n\n\nclass InputOutputMultiTower(object):\n  \"\"\"Mix-in class for blocks with inputs & outputs and multiple mini-batches.\"\"\"\n\n  def __init__(self, *args, **kwargs):\n    self.__inputs = []\n    self.__outputs = []\n    super(InputOutputMultiTower, self).__init__(*args, **kwargs)\n\n  def _process_data(self, grads_list):\n    \"\"\"Process data into the format used by the factors.\n\n    This function takes inputs and grads_lists data and processes it into\n    one of the formats expected by the FisherFactor classes (depending on\n    the value of the global configuration variable TOWER_STRATEGY).\n\n    The initial format of self._inputs is expected to be a list of Tensors\n    over towers. 
Similarly grads_lists is expected to be a list over sources\n    of such lists.\n\n    If TOWER_STRATEGY is \"concat\", 'inputs' becomes a tuple containing a single\n    tensor (represented as a PartitionedTensor object) equal to the\n    concatenation (across towers) of all of the elements of self._inputs. And\n    similarly grads_list is formatted into a tuple (over sources) of such\n    tensors (also represented as PartitionedTensors).\n\n    If TOWER_STRATEGY is \"separate\", formatting of inputs and grads_list\n    remains unchanged from the initial format (although possibly converting\n    from lists into tuples).\n\n    Args:\n      grads_list: grads_list in its initial format (see above).\n\n    Returns:\n      inputs: self._inputs transformed into the appropriate format (see\n        above).\n      grads_list: grads_list transformed into the appropriate format (see\n        above).\n\n    Raises:\n      ValueError: if TOWER_STRATEGY is not one of \"separate\" or \"concat\".\n    \"\"\"\n    inputs = self._inputs\n    # inputs is a list over towers of Tensors\n    # grads_list is a list of list with the first index being sources and the\n    # second being towers.\n    if fisher_factors.TOWER_STRATEGY == \"concat\":\n      # Merge towers together into a PartitionedTensor. 
We package it in\n      # a singleton tuple since the factors will expect a list over towers\n      inputs = (utils.PartitionedTensor(inputs),)\n      # Do the same for grads_list but preserve leading sources dimension\n      grads_list = tuple((utils.PartitionedTensor(grads),)\n                         for grads in grads_list)\n    elif fisher_factors.TOWER_STRATEGY == \"separate\":\n      inputs = tuple(inputs)\n      grads_list = tuple(grads_list)\n\n    else:\n      raise ValueError(\"Global config variable TOWER_STRATEGY must be one of \"\n                       \"'concat' or 'separate'.\")\n\n    return inputs, grads_list\n\n  def tensors_to_compute_grads(self):\n    \"\"\"Tensors to compute derivative of loss with respect to.\"\"\"\n    return tuple(self._outputs)\n\n  def register_additional_tower(self, inputs, outputs):\n    self._inputs.append(inputs)\n    self._outputs.append(outputs)\n\n  @property\n  def num_registered_towers(self):\n    result = len(self._inputs)\n    assert result == len(self._outputs)\n    return result\n\n  @property\n  def _inputs(self):\n    return self.__inputs\n\n  @property\n  def _outputs(self):\n    return self.__outputs\n\n\nclass FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):\n  \"\"\"FisherBlock for fully-connected (dense) layers using a diagonal approx.\n\n  Estimates the Fisher Information matrix's diagonal entries for a fully\n  connected layer. Unlike NaiveDiagonalFB this uses the low-variance \"sum of\n  squares\" estimator.\n\n  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index\n  into it. We are interested in Fisher(params)[i, i]. This is,\n\n    Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]\n                         = E[ v(x, y, params)[i] ^ 2 ]\n\n  Consider fully connected layer in this model with (unshared) weight matrix\n  'w'. 
For an example 'x' that produces layer inputs 'a' and output\n  preactivations 's',\n\n    v(x, y, w) = vec( a (d loss / d s)^T )\n\n  This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding\n  to the layer's parameters 'w'.\n  \"\"\"\n\n  def __init__(self, layer_collection, has_bias=False):\n    \"\"\"Creates a FullyConnectedDiagonalFB block.\n\n    Args:\n      layer_collection: The LayerCollection object which owns this block.\n      has_bias: bool. If True, estimates Fisher with respect to a bias\n        parameter as well as the layer's weights.\n          (Default: False)\n    \"\"\"\n    self._has_bias = has_bias\n\n    super(FullyConnectedDiagonalFB, self).__init__(layer_collection)\n\n  def instantiate_factors(self, grads_list, damping):\n    inputs, grads_list = self._process_data(grads_list)\n\n    self._factor = self._layer_collection.make_or_get_factor(\n        fisher_factors.FullyConnectedDiagonalFactor,\n        (inputs, grads_list, self._has_bias))\n\n    self._damping_func = _package_func(lambda: damping, (damping,))\n\n\nclass ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):\n  \"\"\"FisherBlock for 2-D convolutional layers using a diagonal approx.\n\n  Estimates the Fisher Information matrix's diagonal entries for a convolutional\n  layer. Unlike NaiveDiagonalFB this uses the low-variance \"sum of squares\"\n  estimator.\n\n  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index\n  into it. We are interested in Fisher(params)[i, i]. This is,\n\n    Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]\n                         = E[ v(x, y, params)[i] ^ 2 ]\n\n  Consider a convolutional layer in this model with (unshared) filter matrix\n  'w'. 
For an example image 'x' that produces layer inputs 'a' and output\n  preactivations 's',\n\n    v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )\n\n  where 'loc' is a single (x, y) location in an image.\n\n  This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding\n  to the layer's parameters 'w'.\n  \"\"\"\n\n  def __init__(self,\n               layer_collection,\n               params,\n               strides,\n               padding,\n               data_format=None,\n               dilations=None,\n               patch_mask=None):\n    \"\"\"Creates a ConvDiagonalFB block.\n\n    Args:\n      layer_collection: The LayerCollection object which owns this block.\n      params: The parameters (Tensor or tuple of Tensors) of this layer. If\n        kernel alone, a Tensor of shape [kernel_height, kernel_width,\n        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements\n        containing the previous and a Tensor of shape [out_channels].\n      strides: The stride size in this layer (1-D Tensor of length 4).\n      padding: The padding in this layer (e.g. \"SAME\").\n      data_format: str or None. Format of input data.\n      dilations: List of 4 ints or None. Rate for dilation along all dimensions.\n      patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]\n        or None. If not None this is multiplied against the extracted patches\n        Tensor (broadcasting along the batch dimension) before statistics are\n        computed. 
(Default: None)\n\n    Raises:\n      ValueError: if strides is not length-4.\n      ValueError: if dilations is not length-4.\n      ValueError: if channel is not last dimension.\n    \"\"\"\n    if len(strides) != 4:\n      raise ValueError(\"strides must contain 4 numbers.\")\n\n    if dilations is None:\n      dilations = [1, 1, 1, 1]\n\n    if len(dilations) != 4:\n      raise ValueError(\"dilations must contain 4 numbers.\")\n\n    if not utils.is_data_format_channel_last(data_format):\n      raise ValueError(\"data_format must be channels-last.\")\n\n    self._strides = maybe_tuple(strides)\n    self._padding = padding\n    self._data_format = data_format\n    self._dilations = maybe_tuple(dilations)\n    self._has_bias = isinstance(params, (tuple, list))\n\n    fltr = params[0] if self._has_bias else params\n    self._filter_shape = tuple(fltr.shape.as_list())\n\n    if len(self._filter_shape) != 4:\n      raise ValueError(\n          \"Convolution filter must be of shape\"\n          \" [filter_height, filter_width, in_channels, out_channels].\")\n\n    self._patch_mask = patch_mask\n\n    super(ConvDiagonalFB, self).__init__(layer_collection)\n\n  @property\n  def _factor_implementation(self):\n    return fisher_factors.ConvDiagonalFactor\n\n  def instantiate_factors(self, grads_list, damping):\n    inputs, grads_list = self._process_data(grads_list)\n\n    # Infer number of locations upon which convolution is applied.\n    self._num_locations = utils.num_conv_locations(inputs[0].shape.as_list(),\n                                                   list(self._filter_shape),\n                                                   self._strides,\n                                                   self._padding)\n\n    self._factor = self._layer_collection.make_or_get_factor(\n        self._factor_implementation,\n        (inputs, grads_list, self._filter_shape, self._strides, self._padding,\n         self._data_format, self._dilations, self._has_bias, 
self._patch_mask))\n\n    def damping_func():\n      return self._num_locations * normalize_damping(damping,\n                                                     self._num_locations)\n\n    damping_id = (self._num_locations, \"mult\", \"normalize_damping\", damping,\n                  self._num_locations)\n    self._damping_func = _package_func(damping_func, damping_id)\n\n\nclass ScaleAndShiftFullFB(InputOutputMultiTower, FullFB):\n  \"\"\"A FisherBlock class for scale and shift ops that uses no approximations.\n\n  This class estimates the same thing that NaiveFullFB would (when applied\n  to the scale and shift params), but with a lower variance estimator. In\n  particular it uses a \"sum the squares estimator\", and thus the variance will\n  shrink as 1/batch_size.\n  \"\"\"\n\n  def __init__(self, layer_collection, broadcast_dims_scale,\n               broadcast_dims_shift=None, has_shift=True):\n    \"\"\"Creates a ScaleAndShiftFullFB block.\n\n    Args:\n      layer_collection: The LayerCollection object which owns this block.\n      broadcast_dims_scale: A list of dimension indices that are broadcast\n        along during the scale operation. Does not include batch dimension.\n      broadcast_dims_shift: A list of dimension indices that are broadcast\n        along during the shift operation. Does not include batch dimension.\n      has_shift: bool. 
If True, estimates Fisher with respect to a shift\n        parameter as well the scale parameter (which is always included).\n    \"\"\"\n    self._broadcast_dims_scale = broadcast_dims_scale\n    self._broadcast_dims_shift = broadcast_dims_shift\n    self._has_shift = has_shift\n\n    super(ScaleAndShiftFullFB, self).__init__(layer_collection)\n\n  def instantiate_factors(self, grads_list, damping):\n\n    inputs, grads_list = self._process_data(grads_list)\n\n    self._factor = self._layer_collection.make_or_get_factor(\n        fisher_factors.ScaleAndShiftFullFactor,\n        (inputs, grads_list, self._broadcast_dims_scale,\n         self._broadcast_dims_shift, self._has_shift))\n\n    self._damping_func = _package_func(lambda: damping, (damping,))\n\n\nclass ScaleAndShiftDiagonalFB(InputOutputMultiTower, DiagonalFB):\n  \"\"\"A FisherBlock class for scale and shift ops that uses a diagonal approx.\n\n  This class estimates the same thing that NaiveDiagonalFB would (when applied\n  to the scale and shift params), but with a lower variance estimator. In\n  particular it uses a \"sum the squares estimator\", and thus the variance will\n  shrink as 1/batch_size.\n  \"\"\"\n\n  def __init__(self, layer_collection, broadcast_dims_scale,\n               broadcast_dims_shift=None, has_shift=True):\n    \"\"\"Creates a ScaleAndShiftDiagonalFB block.\n\n    Args:\n      layer_collection: The LayerCollection object which owns this block.\n      broadcast_dims_scale: A list of dimension indices that are broadcast\n        along during the scale operation. Does not include batch dimension.\n      broadcast_dims_shift: A list of dimension indices that are broadcast\n        along during the shift operation. Does not include batch dimension.\n      has_shift: bool. 
If True, estimates Fisher with respect to a shift\n        parameter as well the scale parameter (which is always included).\n    \"\"\"\n    self._broadcast_dims_scale = broadcast_dims_scale\n    self._broadcast_dims_shift = broadcast_dims_shift\n    self._has_shift = has_shift\n\n    super(ScaleAndShiftDiagonalFB, self).__init__(layer_collection)\n\n  def instantiate_factors(self, grads_list, damping):\n\n    inputs, grads_list = self._process_data(grads_list)\n\n    self._factor = self._layer_collection.make_or_get_factor(\n        fisher_factors.ScaleAndShiftDiagonalFactor,\n        (inputs, grads_list, self._broadcast_dims_scale,\n         self._broadcast_dims_shift, self._has_shift))\n\n    self._damping_func = _package_func(lambda: damping, (damping,))\n\n\nclass KroneckerProductFB(FisherBlock):\n  \"\"\"A base class for blocks with separate input and output Kronecker factors.\n\n  The Fisher block is approximated as a Kronecker product of the input and\n  output factors.\n  \"\"\"\n\n  def _setup_damping(self, damping, normalization=None):\n    \"\"\"Makes functions that compute the damping values for both factors.\"\"\"\n    def compute_damping():\n      if normalization is not None:\n        maybe_normalized_damping = normalize_damping(damping, normalization)\n      else:\n        maybe_normalized_damping = damping\n\n      return compute_pi_adjusted_damping(\n          self._input_factor.get_cov_as_linear_operator(),\n          self._output_factor.get_cov_as_linear_operator(),\n          maybe_normalized_damping**0.5)\n\n    if normalization is not None:\n      damping_id = (\"compute_pi_adjusted_damping\",\n                    \"cov\", self._input_factor.name,\n                    \"cov\", self._output_factor.name,\n                    \"normalize_damping\", damping, normalization, \"power\", 0.5)\n    else:\n      damping_id = (\"compute_pi_adjusted_damping\",\n                    \"cov\", self._input_factor.name,\n                    \"cov\", 
self._output_factor.name,\n                    damping, \"power\", 0.5)\n\n    self._input_damping_func = _package_func(lambda: compute_damping()[0],\n                                             damping_id + (\"ref\", 0))\n    self._output_damping_func = _package_func(lambda: compute_damping()[1],\n                                              damping_id + (\"ref\", 1))\n\n    # Also store the damping op for access to the effective damping later on,\n    # such as when writing to summary.\n    if normalization is not None:\n      self._damping = normalize_damping(damping, normalization)\n    else:\n      self._damping = damping\n\n  def register_matpower(self, exp):\n    self._input_factor.register_matpower(exp, self._input_damping_func)\n    self._output_factor.register_matpower(exp, self._output_damping_func)\n\n  def register_cholesky(self):\n    self._input_factor.register_cholesky(self._input_damping_func)\n    self._output_factor.register_cholesky(self._output_damping_func)\n\n  def register_cholesky_inverse(self):\n    self._input_factor.register_cholesky_inverse(self._input_damping_func)\n    self._output_factor.register_cholesky_inverse(self._output_damping_func)\n\n  @property\n  def damping(self):\n    \"\"\"A copy of the damping op.\n\n    This is not used (and should never be used) in KFAC computations. 
A valid\n    usage of this property could be to write damping values to the summary.\n\n    Returns:\n      0-D Tensor.\n    \"\"\"\n    return self._damping\n\n  @property\n  def input_factor(self):\n    return self._input_factor\n\n  @property\n  def output_factor(self):\n    return self._output_factor\n\n  @property\n  def _renorm_coeff(self):\n    \"\"\"Kronecker factor multiplier coefficient.\n\n    If this FisherBlock is represented as 'FB = c * kron(left, right)', then\n    this is 'c'.\n\n    Returns:\n      0-D Tensor.\n    \"\"\"\n    return 1.0\n\n  def _multiply_factored_matrix(self, left_factor, right_factor, vector,\n                                extra_scale=1.0, transpose_left=False,\n                                transpose_right=False):\n    \"\"\"Multiplies a factored matrix.\"\"\"\n    reshaped_vector = utils.layer_params_to_mat2d(vector)\n    reshaped_out = right_factor.matmul_right(reshaped_vector,\n                                             adjoint=transpose_right)\n    reshaped_out = left_factor.matmul(reshaped_out,\n                                      adjoint=transpose_left)\n    if extra_scale != 1.0:\n      reshaped_out = tf.scalar_mul(extra_scale, reshaped_out)\n    return utils.mat2d_to_layer_params(vector, reshaped_out)\n\n  def multiply_matpower(self, vector, exp):\n    left_factor = self._input_factor.get_matpower(\n        exp, self._input_damping_func)\n    right_factor = self._output_factor.get_matpower(\n        exp, self._output_damping_func)\n    extra_scale = float(self._renorm_coeff)**exp\n    return self._multiply_factored_matrix(left_factor, right_factor, vector,\n                                          extra_scale=extra_scale)\n\n  def multiply_cholesky(self, vector, transpose=False):\n    left_factor = self._input_factor.get_cholesky(self._input_damping_func)\n    right_factor = self._output_factor.get_cholesky(self._output_damping_func)\n    extra_scale = float(self._renorm_coeff)**0.5\n    return 
self._multiply_factored_matrix(left_factor, right_factor, vector,\n                                          extra_scale=extra_scale,\n                                          transpose_left=transpose,\n                                          transpose_right=not transpose)\n\n  def multiply_cholesky_inverse(self, vector, transpose=False):\n    left_factor = self._input_factor.get_cholesky_inverse(\n        self._input_damping_func)\n    right_factor = self._output_factor.get_cholesky_inverse(\n        self._output_damping_func)\n    extra_scale = float(self._renorm_coeff)**-0.5\n    return self._multiply_factored_matrix(left_factor, right_factor, vector,\n                                          extra_scale=extra_scale,\n                                          transpose_left=transpose,\n                                          transpose_right=not transpose)\n\n  def full_fisher_block(self):\n    \"\"\"Explicitly constructs the full Fisher block.\n\n    Used for testing purposes. (In general, the result may be very large.)\n\n    Returns:\n      The full Fisher block.\n    \"\"\"\n    left_factor = self._input_factor.get_cov_as_linear_operator().to_dense()\n    right_factor = self._output_factor.get_cov_as_linear_operator().to_dense()\n    return self._renorm_coeff * utils.kronecker_product(left_factor,\n                                                        right_factor)\n\n\nclass FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):\n  \"\"\"K-FAC FisherBlock for fully-connected (dense) layers.\n\n  This uses the Kronecker-factorized approximation from the original\n  K-FAC paper (https://arxiv.org/abs/1503.05671)\n  \"\"\"\n\n  def __init__(self, layer_collection, has_bias=False,\n               diagonal_approx_for_input=False,\n               diagonal_approx_for_output=False):\n    \"\"\"Creates a FullyConnectedKFACBasicFB block.\n\n    Args:\n      layer_collection: The LayerCollection object which owns this block.\n      has_bias: 
bool. If True, estimates Fisher with respect to a bias\n        parameter as well as the layer's weights.\n        (Default: False)\n      diagonal_approx_for_input: Whether to use diagonal approximation for the\n        input Kronecker factor. (Default: False)\n      diagonal_approx_for_output: Whether to use diagonal approximation for the\n        output Kronecker factor. (Default: False)\n    \"\"\"\n    self._has_bias = has_bias\n    self._diagonal_approx_for_input = diagonal_approx_for_input\n    self._diagonal_approx_for_output = diagonal_approx_for_output\n\n    super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)\n\n  def instantiate_factors(self, grads_list, damping):\n    \"\"\"Instantiate Kronecker Factors for this FisherBlock.\n\n    Args:\n      grads_list: List of list of Tensors. grads_list[i][j] is the\n        gradient of the loss with respect to 'outputs' from source 'i' and\n        tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].\n      damping: 0-D Tensor or float. 
'damping' * identity is approximately added\n        to this FisherBlock's Fisher approximation.\n    \"\"\"\n    inputs, grads_list = self._process_data(grads_list)\n\n    if self._diagonal_approx_for_input:\n      self._input_factor = self._layer_collection.make_or_get_factor(\n          fisher_factors.DiagonalKroneckerFactor,\n          ((inputs,), self._has_bias))\n    else:\n      self._input_factor = self._layer_collection.make_or_get_factor(\n          fisher_factors.FullyConnectedKroneckerFactor,\n          ((inputs,), self._has_bias))\n\n    if self._diagonal_approx_for_output:\n      self._output_factor = self._layer_collection.make_or_get_factor(\n          fisher_factors.DiagonalKroneckerFactor,\n          (grads_list,))\n    else:\n      self._output_factor = self._layer_collection.make_or_get_factor(\n          fisher_factors.FullyConnectedKroneckerFactor,\n          (grads_list,))\n\n    self._setup_damping(damping)\n\n\nclass ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):\n  \"\"\"FisherBlock for convolutional layers using the basic KFC approx.\n\n  Estimates the Fisher Information matrix's blog for a convolutional\n  layer.\n\n  Consider a convolutional layer in this model with (unshared) filter matrix\n  'w'. 
For a minibatch that produces inputs 'a' and output preactivations 's',\n  this FisherBlock estimates,\n\n    F(w) = #locations * kronecker(E[flat(a) flat(a)^T],\n                                  E[flat(ds) flat(ds)^T])\n\n  where\n\n    ds = (d / ds) log p(y | x, w)\n    #locations = number of (x, y) locations where 'w' is applied.\n\n  where the expectation is taken over all examples and locations and flat()\n  concatenates an array's leading dimensions.\n\n  See equation 23 in https://arxiv.org/abs/1602.01407 for details.\n  \"\"\"\n\n  def __init__(self,\n               layer_collection,\n               params,\n               padding,\n               strides=None,\n               dilation_rate=None,\n               data_format=None,\n               extract_patches_fn=None,\n               sub_sample_inputs=None,\n               sub_sample_patches=None,\n               use_sua_approx_for_input_factor=False,\n               patch_mask=None):\n    \"\"\"Creates a ConvKFCBasicFB block.\n\n    Args:\n      layer_collection: The LayerCollection object which owns this block.\n      params: The parameters (Tensor or tuple of Tensors) of this layer. If\n        kernel alone, a Tensor of shape [..spatial_filter_shape..,\n        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements\n        containing the previous and a Tensor of shape [out_channels].\n      padding: str. Padding method.\n      strides: List of ints or None. Contains [..spatial_filter_strides..] if\n        'extract_patches_fn' is compatible with tf.nn.convolution(), else\n        [1, ..spatial_filter_strides, 1].\n      dilation_rate: List of ints or None. Rate for dilation along each spatial\n        dimension if 'extract_patches_fn' is compatible with\n        tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].\n      data_format: str or None. Format of input data.\n      extract_patches_fn: str or None. Name of function that extracts image\n        patches. 
One of \"extract_convolution_patches\", \"extract_image_patches\",\n        \"extract_pointwise_conv2d_patches\".\n      sub_sample_inputs: `bool`. If True, then subsample the inputs from which\n        the image patches are extracted. (Default: None)\n      sub_sample_patches: `bool`, If `True` then subsample the extracted\n        patches. (Default: None)\n      use_sua_approx_for_input_factor: `bool`, If `True` then use\n        `ConvInputSUAKroneckerFactor` for input factor. Otherwise use\n        `ConvInputKroneckerFactor`. (Default: None)\n      patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]\n        or None. If not None this is multiplied against the extracted patches\n        Tensor (broadcasting along the batch dimension) before statistics are\n        computed in the input factor. (Default: None)\n    \"\"\"\n    self._padding = padding\n    self._strides = maybe_tuple(strides)\n    self._dilation_rate = maybe_tuple(dilation_rate)\n    self._data_format = data_format\n    self._extract_patches_fn = extract_patches_fn\n    self._has_bias = isinstance(params, (tuple, list))\n    self._use_sua_approx_for_input_factor = use_sua_approx_for_input_factor\n\n    fltr = params[0] if self._has_bias else params\n    self._filter_shape = tuple(fltr.shape.as_list())\n\n    self._sub_sample_inputs = sub_sample_inputs\n    self._sub_sample_patches = sub_sample_patches\n    self._patch_mask = patch_mask\n\n    super(ConvKFCBasicFB, self).__init__(layer_collection)\n\n  def instantiate_factors(self, grads_list, damping):\n    inputs, grads_list = self._process_data(grads_list)\n\n    # Infer number of locations upon which convolution is applied.\n    self._num_locations = utils.num_conv_locations(inputs[0].shape.as_list(),\n                                                   list(self._filter_shape),\n                                                   self._strides,\n                                                   self._padding)\n\n    if 
self._use_sua_approx_for_input_factor:\n      self._input_factor = self._layer_collection.make_or_get_factor(\n          fisher_factors.ConvInputSUAKroneckerFactor,\n          (inputs, self._filter_shape, self._has_bias))\n    else:\n      self._input_factor = self._layer_collection.make_or_get_factor(\n          fisher_factors.ConvInputKroneckerFactor,\n          (inputs, self._filter_shape, self._padding, self._strides,\n           self._dilation_rate, self._data_format, self._extract_patches_fn,\n           self._has_bias, self._sub_sample_inputs, self._sub_sample_patches,\n           self._patch_mask))\n\n    self._output_factor = self._layer_collection.make_or_get_factor(\n        fisher_factors.ConvOutputKroneckerFactor, (grads_list,))\n\n    self._setup_damping(damping, normalization=self._num_locations)\n\n  @property\n  def _renorm_coeff(self):\n    return self._num_locations\n\n\nclass DepthwiseConvDiagonalFB(ConvDiagonalFB):\n  \"\"\"FisherBlock for depthwise_conv2d().\n\n  Equivalent to ConvDiagonalFB applied to each input channel in isolation.\n  \"\"\"\n\n  def __init__(self,\n               layer_collection,\n               params,\n               strides,\n               padding,\n               rate=None,\n               data_format=None):\n    \"\"\"Creates a DepthwiseConvKFCBasicFB block.\n\n    Args:\n      layer_collection: The LayerCollection object which owns this block.\n      params: Tensor of shape [filter_height, filter_width, in_channels,\n        channel_multiplier].\n      strides: List of 4 ints. Strides along all dimensions.\n      padding: str. Padding method.\n      rate: List of 4 ints or None. Rate for dilation along all dimensions.\n      data_format: str or None. 
Format of input data.\n\n    Raises:\n      NotImplementedError: If parameters contains bias.\n      ValueError: If filter is not 4-D.\n      ValueError: If strides is not length-4.\n      ValueError: If rates is not length-2.\n      ValueError: If channels are not last dimension.\n    \"\"\"\n    if isinstance(params, (tuple, list)):\n      raise NotImplementedError(\"Bias not yet supported.\")\n\n    if params.shape.ndims != 4:\n      raise ValueError(\"Filter must be 4-D.\")\n\n    if len(strides) != 4:\n      raise ValueError(\"strides must account for 4 dimensions.\")\n\n    if rate is not None:\n      if len(rate) != 2:\n        raise ValueError(\"rate must only account for spatial dimensions.\")\n      rate = [1, rate[0], rate[1], 1]  # conv2d expects 4-element rate.\n\n    if not utils.is_data_format_channel_last(data_format):\n      raise ValueError(\"data_format must be channels-last.\")\n\n    super(DepthwiseConvDiagonalFB, self).__init__(\n        layer_collection=layer_collection,\n        params=params,\n        strides=strides,\n        padding=padding,\n        dilations=rate,\n        data_format=data_format)\n\n    # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().\n    filter_height, filter_width, in_channels, channel_multiplier = (\n        params.shape.as_list())\n    self._filter_shape = (filter_height, filter_width, in_channels,\n                          in_channels * channel_multiplier)\n\n  def _multiply_matrix(self, matrix, vector):\n    conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)\n    conv2d_result = super(\n        DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector)\n    return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)\n\n\nclass DepthwiseConvKFCBasicFB(ConvKFCBasicFB):\n  \"\"\"FisherBlock for depthwise_conv2d().\n\n  Equivalent to ConvKFCBasicFB applied to each input channel in isolation.\n  \"\"\"\n\n  def __init__(self,\n               
layer_collection,\n               params,\n               strides,\n               padding,\n               rate=None,\n               data_format=None):\n    \"\"\"Creates a DepthwiseConvKFCBasicFB block.\n\n    Args:\n      layer_collection: The LayerCollection object which owns this block.\n      params: Tensor of shape [filter_height, filter_width, in_channels,\n        channel_multiplier].\n      strides: List of 4 ints. Strides along all dimensions.\n      padding: str. Padding method.\n      rate: List of 4 ints or None. Rate for dilation along all dimensions.\n      data_format: str or None. Format of input data.\n\n    Raises:\n      NotImplementedError: If parameters contains bias.\n      ValueError: If filter is not 4-D.\n      ValueError: If strides is not length-4.\n      ValueError: If rates is not length-2.\n      ValueError: If channels are not last dimension.\n    \"\"\"\n    if isinstance(params, (tuple, list)):\n      raise NotImplementedError(\"Bias not yet supported.\")\n\n    if params.shape.ndims != 4:\n      raise ValueError(\"Filter must be 4-D.\")\n\n    if len(strides) != 4:\n      raise ValueError(\"strides must account for 4 dimensions.\")\n\n    if rate is not None:\n      if len(rate) != 2:\n        raise ValueError(\"rate must only account for spatial dimensions.\")\n      rate = [1, rate[0], rate[1], 1]  # conv2d expects 4-element rate.\n\n    if not utils.is_data_format_channel_last(data_format):\n      raise ValueError(\"data_format must be channels-last.\")\n\n    super(DepthwiseConvKFCBasicFB, self).__init__(\n        layer_collection=layer_collection,\n        params=params,\n        padding=padding,\n        strides=strides,\n        dilation_rate=rate,\n        data_format=data_format,\n        extract_patches_fn=\"extract_image_patches\")\n\n    # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().\n    filter_height, filter_width, in_channels, channel_multiplier = (\n        params.shape.as_list())\n  
  self._filter_shape = (filter_height, filter_width, in_channels,\n                          in_channels * channel_multiplier)\n\n  def _multiply_factored_matrix(self, left_factor, right_factor, vector,\n                                extra_scale=1.0, transpose_left=False,\n                                transpose_right=False):\n    conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)\n    conv2d_result = super(\n        DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix(\n            left_factor, right_factor, conv2d_vector, extra_scale=extra_scale,\n            transpose_left=transpose_left, transpose_right=transpose_right)\n    return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)\n\n\ndef depthwise_conv2d_filter_to_conv2d_filter(filter, name=None):  # pylint: disable=redefined-builtin\n  \"\"\"Converts a convolution filter for use with conv2d.\n\n  Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's\n  compatible with tf.nn.conv2d().\n\n  Args:\n    filter: Tensor of shape [height, width, in_channels, channel_multiplier].\n    name: None or str. Name of Op.\n\n  Returns:\n    Tensor of shape [height, width, in_channels, out_channels].\n\n  \"\"\"\n  with tf.name_scope(name, \"depthwise_conv2d_filter_to_conv2d_filter\",\n                     [filter]):\n    filter = tf.convert_to_tensor(filter)\n    filter_height, filter_width, in_channels, channel_multiplier = (\n        filter.shape.as_list())\n\n    results = []\n    for i in range(in_channels):\n      # Slice out one in_channel's filter. 
Insert zeros around it to force it\n      # to affect that channel and that channel alone.\n      elements = []\n      if i > 0:\n        elements.append(\n            tf.zeros([filter_height, filter_width, i, channel_multiplier]))\n      elements.append(filter[:, :, i:(i + 1), :])\n      if i + 1 < in_channels:\n        elements.append(\n            tf.zeros([\n                filter_height, filter_width, in_channels - (i + 1),\n                channel_multiplier\n            ]))\n\n      # Concat along in_channel.\n      results.append(tf.concat(elements, axis=-2, name=\"in_channel_%d\" % i))\n\n    # Concat along out_channel.\n    return tf.concat(results, axis=-1, name=\"out_channel\")\n\n\ndef conv2d_filter_to_depthwise_conv2d_filter(filter, name=None):  # pylint: disable=redefined-builtin\n  \"\"\"Converts a convolution filter for use with depthwise_conv2d.\n\n  Transforms a filter for use with tf.nn.conv2d() to one that's\n  compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along\n  the diagonal.\n\n  Args:\n    filter: Tensor of shape [height, width, in_channels, out_channels].\n    name: None or str. 
Name of Op.\n\n  Returns:\n    Tensor of shape,\n      [height, width, in_channels, channel_multiplier]\n\n  Raises:\n    ValueError: if out_channels is not evenly divisible by in_channels.\n  \"\"\"\n  with tf.name_scope(name, \"conv2d_filter_to_depthwise_conv2d_filter\",\n                     [filter]):\n    filter = tf.convert_to_tensor(filter)\n    filter_height, filter_width, in_channels, out_channels = (\n        filter.shape.as_list())\n\n    if out_channels % in_channels != 0:\n      raise ValueError(\"out_channels must be evenly divisible by in_channels.\")\n    channel_multiplier = out_channels // in_channels\n\n    results = []\n    filter = tf.reshape(filter, [\n        filter_height, filter_width, in_channels, in_channels,\n        channel_multiplier\n    ])\n    for i in range(in_channels):\n      # Slice out output corresponding to the correct filter.\n      filter_slice = tf.reshape(\n          filter[:, :, i, i, :],\n          [filter_height, filter_width, 1, channel_multiplier])\n      results.append(filter_slice)\n\n    # Concat along out_channel.\n    return tf.concat(results, axis=-2, name=\"in_channels\")\n\n\ndef maybe_tuple(obj):\n  if not isinstance(obj, list):\n    return obj\n  return tuple(obj)\n\n\nclass InputOutputMultiTowerMultiUse(InputOutputMultiTower):\n  \"\"\"Adds methods for multi-use/time-step case to InputOutputMultiTower.\"\"\"\n\n  def __init__(self, num_uses=None, *args, **kwargs):\n    self._num_uses = num_uses\n    super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs)\n\n  def _process_data(self, grads_list):\n    \"\"\"Process temporal/multi-use data into the format used by the factors.\n\n    This function takes inputs and grads_lists data and processes it into\n    one of the formats expected by the FisherFactor classes (depending on\n    the value of the global configuration variable TOWER_STRATEGY).\n\n    It accepts the data in one of two initial formats. 
The first possible\n    format is where self._inputs is a list of list of Tensors. The first index\n    is tower, the second is use/time-step. grads_list, meanwhile, is a list\n    over sources of such lists of lists.\n\n    The second possible data format is where self._inputs is a list of Tensors\n    (over towers), where each tensor either has shape\n    [num_uses, batch_size, ...] or each tensor has shape\n    [num_uses*batch_size, ...] (which is formed by reshaping tensors of the\n    first format). And similarly grads_list is a list over sources of such lists\n    of Tensors.\n\n    There are two possible formats which inputs and grads_list are transformed\n    into.\n\n    If TOWER_STRATEGY is \"concat\", 'inputs' becomes a tuple containing\n    a single tensor (represented as a PartitionedTensor object) with all of\n    the data from the towers, as well as the uses/time-steps, concatenated\n    together. The format of this tensor is the same as the second input data\n    format above. Similarly, grads_list is a tuple over sources of such\n    lists of tensors.\n\n    If TOWER_STRATEGY is \"separate\" the inputs are formatted into lists of\n    tensors over towers. Each of these tensors has a similar format to\n    the tensor produced by the \"concat\" option, except that each contains\n    only the data from a single tower. 
grads_list is similarly formatted\n    into a tuple over sources of such tuples.\n\n    Args:\n      grads_list: grads_list in its initial format (see above).\n\n    Returns:\n      inputs: self._inputs transformed into the appropriate format (see\n        above).\n      grads_list: grads_list transformed into the appropriate format (see\n        above).\n\n    Raises:\n      ValueError: If TOWER_STRATEGY is not one of \"separate\" or \"concat\".\n      ValueError: If the given/initial format of self._inputs and grads_list\n        isn't recognized, or doesn't agree with self._num_uses.\n    \"\"\"\n    inputs = self._inputs\n\n    # The first data format.\n    if isinstance(inputs[0], (list, tuple)):\n\n      num_uses = len(inputs[0])\n\n      if self._num_uses is not None and self._num_uses != num_uses:\n        raise ValueError(\"num_uses argument doesn't match length of inputs.\")\n      else:\n        self._num_uses = num_uses\n\n      # Check that all mini-batches/towers have the same number of uses\n      if not all(len(input_) == num_uses for input_ in inputs):\n        raise ValueError(\"Length of inputs argument is inconsistent across \"\n                         \"towers.\")\n\n      if fisher_factors.TOWER_STRATEGY == \"concat\":\n        # Reverse the tower and use/time-step indices, so that use is now first,\n        # and towers is second\n        inputs = tuple(zip(*inputs))\n\n        # Flatten the two dimensions\n        inputs = nest.flatten(inputs)\n\n        # Merge everything together into a PartitionedTensor. 
We package it in\n        # a singleton tuple since the factors will expect a list over towers\n        inputs = (utils.PartitionedTensor(inputs),)\n\n      elif fisher_factors.TOWER_STRATEGY == \"separate\":\n        # Merge together the uses/time-step dimension into PartitionedTensors,\n        # but keep the leading dimension (towers) intact for the factors to\n        # process individually.\n        inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs)\n\n      else:\n        raise ValueError(\"Global config variable TOWER_STRATEGY must be one of \"\n                         \"'concat' or 'separate'.\")\n    # The second data format\n    else:\n      inputs = tuple(inputs)\n\n    # Now we perform the analogous processing for grads_list\n\n    # The first data format.\n    if isinstance(grads_list[0][0], (list, tuple)):\n\n      num_uses = len(grads_list[0][0])\n\n      if self._num_uses is not None and self._num_uses != num_uses:\n        raise ValueError(\"num_uses argument doesn't match length of outputs, \"\n                         \"or length of outputs is inconsistent with length of \"\n                         \"inputs.\")\n      else:\n        self._num_uses = num_uses\n\n      if not all(len(grad) == num_uses for grads in grads_list\n                 for grad in grads):\n        raise ValueError(\"Length of outputs argument is inconsistent across \"\n                         \"towers.\")\n\n      if fisher_factors.TOWER_STRATEGY == \"concat\":\n        # Reverse the tower and use/time-step indices, so that use is now first,\n        # and towers is second\n        grads_list = tuple(tuple(zip(*grads)) for grads in grads_list)\n\n        # Flatten the two dimensions, leaving the leading dimension (source)\n        # intact\n        grads_list = tuple(nest.flatten(grads) for grads in grads_list)\n\n        # Merge inner dimensions together into PartitionedTensors. 
We package\n        # them in a singleton tuple since the factors will expect a list over\n        # towers\n        grads_list = tuple((utils.PartitionedTensor(grads),)\n                           for grads in grads_list)\n\n      elif fisher_factors.TOWER_STRATEGY == \"separate\":\n        # Merge together the uses/time-step dimension into PartitionedTensors,\n        # but keep the leading dimension (towers) intact for the factors to\n        # process individually.\n        grads_list = tuple(tuple(utils.PartitionedTensor(grad)\n                                 for grad in grads)\n                           for grads in grads_list)\n\n      else:\n        raise ValueError(\"Global config variable TOWER_STRATEGY must be one of \"\n                         \"'concat' or 'separate'.\")\n\n    # The second data format.\n    else:\n      grads_list = tuple(tuple(grads) for grads in grads_list)\n\n    if self._num_uses is None:\n      raise ValueError(\"You must supply a value for the num_uses argument if \"\n                       \"the number of uses cannot be inferred from inputs or \"\n                       \"outputs arguments (e.g. if they are both given in the \"\n                       \"single Tensor format, instead of as lists of Tensors.\")\n\n    return inputs, grads_list\n\n\nclass FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse,\n                                 KroneckerProductFB):\n  \"\"\"FisherBlock for fully-connected layers that share parameters.\n\n  This class implements the \"independence across time\" approximation from the\n  following paper:\n    https://openreview.net/pdf?id=HyMTkQZAb\n  \"\"\"\n\n  def __init__(self, layer_collection, has_bias=False, num_uses=None,\n               diagonal_approx_for_input=False,\n               diagonal_approx_for_output=False):\n    \"\"\"Creates a FullyConnectedMultiIndepFB block.\n\n    Args:\n      layer_collection: LayerCollection instance.\n      has_bias: bool. 
If True, estimates Fisher with respect to a bias\n        parameter as well as the layer's weights. (Default: False)\n      num_uses: int or None. Number of uses of the layer in the model's graph.\n        Only required if the data is formatted with uses/time folded into the\n        batch dimension (instead of uses/time being a list dimension).\n        (Default: None)\n      diagonal_approx_for_input: Whether to use diagonal approximation for the\n        input Kronecker factor. (Default: False)\n      diagonal_approx_for_output: Whether to use diagonal approximation for the\n        output Kronecker factor. (Default: False)\n    \"\"\"\n    self._has_bias = has_bias\n    self._diagonal_approx_for_input = diagonal_approx_for_input\n    self._diagonal_approx_for_output = diagonal_approx_for_output\n\n    super(FullyConnectedMultiIndepFB, self).__init__(\n        layer_collection=layer_collection,\n        num_uses=num_uses)\n\n  def instantiate_factors(self, grads_list, damping):\n    inputs, grads_list = self._process_data(grads_list)\n\n    if self._diagonal_approx_for_input:\n      self._input_factor = self._layer_collection.make_or_get_factor(\n          fisher_factors.DiagonalMultiKF,\n          ((inputs,), self._num_uses, self._has_bias))\n    else:\n      self._input_factor = self._layer_collection.make_or_get_factor(\n          fisher_factors.FullyConnectedMultiKF,\n          ((inputs,), self._num_uses, self._has_bias))\n\n    if self._diagonal_approx_for_output:\n      self._output_factor = self._layer_collection.make_or_get_factor(\n          fisher_factors.DiagonalMultiKF,\n          (grads_list, self._num_uses))\n    else:\n      self._output_factor = self._layer_collection.make_or_get_factor(\n          fisher_factors.FullyConnectedMultiKF,\n          (grads_list, self._num_uses))\n\n    self._setup_damping(damping, normalization=self._num_uses)\n\n  @property\n  def _renorm_coeff(self):\n    return float(self._num_uses)\n\n\nclass 
ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,\n                               KroneckerProductFB):\n  \"\"\"FisherBlock for 2D convolutional layers using the basic KFC approx.\n\n  Similar to ConvKFCBasicFB except that this version supports multiple\n  uses/time-steps via a standard independence approximation.  Similar to the\n  \"independence across time\" used in FullyConnectedMultiIndepFB but generalized\n  in the obvious way to conv layers.\n  \"\"\"\n\n  def __init__(self,\n               layer_collection,\n               params,\n               padding,\n               strides=None,\n               dilation_rate=None,\n               data_format=None,\n               extract_patches_fn=None,\n               num_uses=None):\n    \"\"\"Creates a ConvKFCBasicMultiIndepFB block.\n\n    Args:\n      layer_collection: The LayerCollection object which owns this block.\n      params: The parameters (Tensor or tuple of Tensors) of this layer. If\n        kernel alone, a Tensor of shape [..spatial_filter_shape..,\n        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements\n        containing the previous and a Tensor of shape [out_channels].\n      padding: str. Padding method.\n      strides: List of ints or None. Contains [..spatial_filter_strides..] if\n        'extract_patches_fn' is compatible with tf.nn.convolution(), else\n        [1, ..spatial_filter_strides, 1].\n      dilation_rate: List of ints or None. Rate for dilation along each spatial\n        dimension if 'extract_patches_fn' is compatible with\n        tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].\n      data_format: str or None. Format of input data.\n      extract_patches_fn: str or None. Name of function that extracts image\n        patches. One of \"extract_convolution_patches\", \"extract_image_patches\",\n        \"extract_pointwise_conv2d_patches\".\n      num_uses: int or None. 
Number of uses of the layer in the model's graph.\n        Only required if the data is formatted with uses/time folded into the\n        batch dimension (instead of uses/time being a list dimension).\n        (Default: None)\n    \"\"\"\n    self._padding = padding\n    self._strides = maybe_tuple(strides)\n    self._dilation_rate = maybe_tuple(dilation_rate)\n    self._data_format = data_format\n    self._extract_patches_fn = extract_patches_fn\n    self._has_bias = isinstance(params, (tuple, list))\n\n    fltr = params[0] if self._has_bias else params\n    self._filter_shape = tuple(fltr.shape.as_list())\n\n    super(ConvKFCBasicMultiIndepFB, self).__init__(\n        layer_collection=layer_collection,\n        num_uses=num_uses)\n\n  def instantiate_factors(self, grads_list, damping):\n    inputs, grads_list = self._process_data(grads_list)\n\n    # Infer number of locations upon which convolution is applied.\n    self._num_locations = utils.num_conv_locations(inputs[0].shape.as_list(),\n                                                   list(self._filter_shape),\n                                                   self._strides,\n                                                   self._padding)\n\n    self._input_factor = self._layer_collection.make_or_get_factor(\n        fisher_factors.ConvInputMultiKF,\n        (inputs, self._filter_shape, self._padding, self._num_uses,\n         self._strides, self._dilation_rate, self._data_format,\n         self._extract_patches_fn, self._has_bias, self._num_uses))\n    self._output_factor = self._layer_collection.make_or_get_factor(\n        fisher_factors.ConvOutputMultiKF, (grads_list, self._num_uses,\n                                           self._data_format))\n\n    self._setup_damping(damping,\n                        normalization=(self._num_locations * self._num_uses))\n\n  @property\n  def _renorm_coeff(self):\n    return self._num_locations * self._num_uses\n\n\nclass SeriesFBApproximation(object):\n  
\"\"\"See FullyConnectedSeriesFB.__init__ for description and usage.\"\"\"\n  option1 = 1\n  option2 = 2\n\n\nclass FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,\n                             KroneckerProductFB):\n  \"\"\"FisherBlock for fully-connected layers that share parameters across time.\n\n  This class implements the \"Option 1\" and \"Option 2\" approximation from the\n  following paper:\n    https://openreview.net/pdf?id=HyMTkQZAb\n\n  See the end of the appendix of the paper for a pseudo-code of the\n  algorithm being implemented by multiply_matpower here.  Note that we are\n  using pre-computed versions of certain matrix-matrix products to speed\n  things up.  This is explicitly explained wherever it is done.\n  \"\"\"\n\n  def __init__(self,\n               layer_collection,\n               has_bias=False,\n               num_uses=None,\n               option=SeriesFBApproximation.option2):\n    \"\"\"Constructs a new `FullyConnectedSeriesFB`.\n\n    Args:\n      layer_collection: The collection of all layers in the K-FAC approximate\n        Fisher information matrix to which this FisherBlock belongs.\n      has_bias: bool. If True, estimates Fisher with respect to a bias\n        parameter as well as the layer's weights.\n      num_uses: int or None. Number of time-steps over which the layer\n        is used. Only required if the data is formatted with time folded into\n        the batch dimension (instead of time being a list dimension).\n        (Default: None)\n      option: A `SeriesFBApproximation` specifying the simplifying assumption\n        to be used in this block. `option1` approximates the cross-covariance\n        over time as a symmetric matrix, while `option2` makes\n        the assumption that training sequences are infinitely long. 
See section\n        3.5 of the paper for more details.\n    \"\"\"\n\n    self._has_bias = has_bias\n    self._option = option\n\n    super(FullyConnectedSeriesFB, self).__init__(\n        layer_collection=layer_collection,\n        num_uses=num_uses)\n\n  @property\n  def _num_timesteps(self):\n    return self._num_uses\n\n  @property\n  def _renorm_coeff(self):\n    # This should no longer be used since the multiply_X functions from the base\n    # class have been overridden\n    assert False\n\n  def instantiate_factors(self, grads_list, damping):\n    inputs, grads_list = self._process_data(grads_list)\n\n    self._input_factor = self._layer_collection.make_or_get_factor(\n        fisher_factors.FullyConnectedMultiKF,\n        ((inputs,), self._num_uses, self._has_bias))\n    self._input_factor.register_cov_dt1()\n\n    self._output_factor = self._layer_collection.make_or_get_factor(\n        fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))\n    self._output_factor.register_cov_dt1()\n\n    self._setup_damping(damping, normalization=self._num_uses)\n\n  def register_matpower(self, exp):\n    if exp != -1:\n      raise NotImplementedError(\"FullyConnectedSeriesFB only supports inverse\"\n                                \"multiplications.\")\n\n    if self._option == SeriesFBApproximation.option1:\n      self._input_factor.register_option1quants(self._input_damping_func)\n      self._output_factor.register_option1quants(self._output_damping_func)\n    elif self._option == SeriesFBApproximation.option2:\n      self._input_factor.register_option2quants(self._input_damping_func)\n      self._output_factor.register_option2quants(self._output_damping_func)\n    else:\n      raise ValueError(\n          \"Unrecognized FullyConnectedSeriesFB approximation: {}\".format(\n              self._option))\n\n  def multiply_matpower(self, vector, exp):\n    if exp != -1:\n      raise NotImplementedError(\"FullyConnectedSeriesFB only supports inverse\"\n      
                          \"multiplications.\")\n\n    # pylint: disable=invalid-name\n\n    Z = utils.layer_params_to_mat2d(vector)\n\n    # Derivations were done for \"batch_dim==1\" case so we need to convert to\n    # that orientation:\n    Z = tf.transpose(Z)\n\n    if self._option == SeriesFBApproximation.option1:\n\n      # Note that L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\n      L_A, psi_A = self._input_factor.get_option1quants(\n          self._input_damping_func)\n      L_G, psi_G = self._output_factor.get_option1quants(\n          self._output_damping_func)\n\n      def gamma(x):\n        # We are assuming that each case has the same number of time-steps.\n        # If this stops being the case one shouldn't simply replace this T\n        # with its average value.  Instead, one needs to go back to the\n        # definition of the gamma function from the paper.\n        T = self._num_timesteps\n        return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))\n\n      # Y = \\gamma( psi_G*psi_A^T ) (computed element-wise)\n      # Even though Y is Z-independent we are recomputing it from the psi's\n      # each since Y depends on both A and G quantities, and it is relatively\n      # cheap to compute.\n      Y = gamma(tf.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)\n\n      # Z = L_G^T * Z * L_A\n      # This is equivalent to the following computation from the original\n      # pseudo-code:\n      # Z = G0^{-1/2} * Z * A0^{-1/2}\n      # Z = U_G^T * Z * U_A\n      Z = tf.matmul(L_G, tf.matmul(Z, L_A), transpose_a=True)\n\n      # Z = Z .* Y\n      Z *= Y\n\n      # Z = L_G * Z * L_A^T\n      # This is equivalent to the following computation from the original\n      # pseudo-code:\n      # Z = U_G * Z * U_A^T\n      # Z = G0^{-1/2} * Z * A0^{-1/2}\n      Z = tf.matmul(L_G, tf.matmul(Z, L_A, transpose_b=True))\n\n    elif self._option == SeriesFBApproximation.option2:\n\n      # Note that P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1},\n   
   # and K_A = A_0^{-1/2} * E_A\\ and\\ K_G = G_0^{-1/2} * E_G.\n      P_A, K_A, mu_A = self._input_factor.get_option2quants(\n          self._input_damping_func)\n      P_G, K_G, mu_G = self._output_factor.get_option2quants(\n          self._output_damping_func)\n\n      # Our approach differs superficially from the pseudo-code in the paper\n      # in order to reduce the total number of matrix-matrix multiplies.\n      # In particular, the first three computations in the pseudo code are\n      # Z = G0^{-1/2} * Z * A0^{-1/2}\n      # Z = Z - hPsi_G^T * Z * hPsi_A\n      # Z = E_G^T * Z * E_A\n      # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}, so that\n      # C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\n      # the entire computation can be written as\n      # Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\n      #     - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\n      #   = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\n      #     - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\n      #   = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\n      #     -  E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\n      #   = K_G^T * Z * K_A  -  K_G^T * P_G * Z * P_A^T * K_A\n      # This final expression is computed by the following two lines:\n      # Z = Z - P_G * Z * P_A^T\n      Z -= tf.matmul(P_G, tf.matmul(Z, P_A, transpose_b=True))\n      # Z = K_G^T * Z * K_A\n      Z = tf.matmul(K_G, tf.matmul(Z, K_A), transpose_a=True)\n\n      # Z = Z ./ (1*1^T - mu_G*mu_A^T)\n      # Be careful with the outer product.  
We don't want to accidentally\n      # make it an inner-product instead.\n      tmp = 1.0 - tf.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A\n      # Prevent some numerical issues by setting any 0.0 eigs to 1.0\n      tmp += 1.0 * tf.cast(tf.equal(tmp, 0.0), dtype=tmp.dtype)\n      Z /= tmp\n\n      # We now perform the transpose/reverse version of the operations\n      # derived above, whose derivation from the original pseudo-code is\n      # analgous.\n      # Z = K_G * Z * K_A^T\n      Z = tf.matmul(K_G, tf.matmul(Z, K_A, transpose_b=True))\n\n      # Z = Z - P_G^T * Z * P_A\n      Z -= tf.matmul(P_G, tf.matmul(Z, P_A), transpose_a=True)\n\n      # Z = normalize (1/E[T]) * Z\n      # Note that this normalization is done because we compute the statistics\n      # by averaging, not summing, over time. (And the gradient is presumably\n      # summed over time, not averaged, and thus their scales are different.)\n      Z /= tf.cast(self._num_timesteps, Z.dtype)\n\n    # Convert back to the \"batch_dim==0\" orientation.\n    Z = tf.transpose(Z)\n\n    return utils.mat2d_to_layer_params(vector, Z)\n\n    # pylint: enable=invalid-name\n\n  def multiply_cholesky(self, vector):\n    raise NotImplementedError(\"FullyConnectedSeriesFB does not support \"\n                              \"Cholesky computations.\")\n\n  def multiply_cholesky_inverse(self, vector):\n    raise NotImplementedError(\"FullyConnectedSeriesFB does not support \"\n                              \"Cholesky computations.\")\n\n"
  },
  {
    "path": "kfac/python/ops/fisher_factors.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"FisherFactor definitions.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\nimport contextlib\nimport math\n# Dependency imports\nimport numpy as np\nimport six\nimport tensorflow.compat.v1 as tf\n\nfrom collections import OrderedDict\n\nfrom tensorflow.python.util import nest\nfrom kfac.python.ops import linear_operator as lo\nfrom kfac.python.ops import utils\n\n\n# Whether to initialize covariance estimators at a zero matrix (or the identity\n# matrix).\nINIT_COVARIANCES_AT_ZERO = True\n\n# Whether to zero-debias the moving averages.\nZERO_DEBIAS = True\n\n# Whether to initialize inverse (and other such matrices computed from the cov\n# matrices) to the zero matrix (or the identity matrix). 
Initializing to\n# zero is a safeguard against anything using the inverse before their first\n# proper update, and so is preferred.\nINIT_INVERSES_AT_ZERO = True\n\n# When the number of inverses requested from a FisherFactor is >= this value,\n# the inverses are computed using an eigenvalue decomposition.\nEIGENVALUE_DECOMPOSITION_THRESHOLD = 4\n\n# Numerical eigenvalues computed from covariance matrix estimates are clipped to\n# be at least as large as this value before they are used to compute inverses or\n# matrix powers. Must be nonnegative.\nEIGENVALUE_CLIPPING_THRESHOLD = 0.0\n\n# When approximating conv layer input factor using spatially uncorrelated\n# activations (`ConvInputSUAKroneckerfactor`) if this is True then assumes the\n# activations to have zero mean.\nASSUME_ZERO_MEAN_ACTIVATIONS = False\n\n# When approximating conv layer input factor using spatially uncorrelated\n# activations (`ConvInputSUAKroneckerfactor`) if this is True then do\n# mean subtraction from covariance matrix. Note this flag is only checked in the\n# case where ASSUME_ZERO_MEAN_ACTIVATIONS is set to True. If\n# ASSUME_ZERO_MEAN_ACTIVATIONS is False then mean is always subtracted from the\n# covariance matrix and this flag is redundant.\nSUBTRACT_MEAN_CONTRIB_FROM_COV = True\n\n# Subsample the inputs passed to the extract image patches. The number of\n# inputs is normally batch_size. If _SUB_SAMPLE_INPUTS = True then\n# the inputs will be randomly subsampled down to a total of\n# _INPUTS_TO_EXTRACT_PATCHES_FACTOR * batch_size.\n#\n# Note that the value of _SUB_SAMPLE_INPUTS can be overridden locally for a\n# particular layer by passing in an argument to the factor class (or the\n# registration function for the corresponding layer).\n_SUB_SAMPLE_INPUTS = False\n_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.2\n\n\n# Subsample the extracted image patches during covariance estimation for\n# input factors in conv layer. 
The number of patches subsampled will be\n# calculated based on the following formula:\n#\n# if _SUB_SAMPLE_PATCHES:\n#   num_patches = min(_MAX_NUM_PATCHES,\n#                     ceil(_MAX_NUM_PATCHES_PER_DIMENSION*dimension))\n# else\n#   num_patches = total_patches\n#\n# where dimension is the number of rows (or columns) of the input factor matrix,\n# which is typically the number of input channels times the number of pixels\n# in a patch.\n#\n# Note that the value of _SUB_SAMPLE_PATCHES can be overridden locally for a\n# particular layer by passing in an argument to the factor class (or the\n# registration function for the corresponding layer).\n_SUB_SAMPLE_PATCHES = False\n_MAX_NUM_PATCHES = 10000000\n_MAX_NUM_PATCHES_PER_DIMENSION = 3.0\n\n\n# If true we use the custom XLA implementation of an op to compute the second\n# moment of the patch vectors. Note that _SUB_SAMPLE_PATCHES doesn't do anything\n# when this is enabled. Also note that _SUB_SAMPLE_INPUTS probably doesn't\n# need to be used either, since that feature was designed to mitigate the\n# extreme memory consumption of the naive implementation of this op.\n_USE_PATCHES_SECOND_MOMENT_OP = False\n\n\n# TOWER_STRATEGY can be one of \"concat\" or \"separate\".  If \"concat\", the data\n# passed to the factors from the blocks will be concatenated across towers\n# (lazily via PartitionedTensor objects).  
Otherwise a tuple of tensors over\n# towers will be passed in, and the factors will iterate over this and do the\n# cov computations separately for each one, averaging the results together.\nTOWER_STRATEGY = \"separate\"\n#TOWER_STRATEGY = \"concat\"\n\n\n# The variable scope names can be edited by passing a custom sanitizer function.\n# By default the scope name is unchanged.\n_GET_SANITIZED_NAME_FN = lambda x: x\n\n\ndef set_global_constants(init_covariances_at_zero=None,\n                         zero_debias=None,\n                         init_inverses_at_zero=None,\n                         eigenvalue_decomposition_threshold=None,\n                         eigenvalue_clipping_threshold=None,\n                         assume_zero_mean_activations=None,\n                         subtract_mean_contrib_from_cov=None,\n                         sub_sample_inputs=None,\n                         inputs_to_extract_patches_factor=None,\n                         sub_sample_patches=None,\n                         max_num_patches=None,\n                         max_num_patches_per_dimension=None,\n                         tower_strategy=None,\n                         get_sanitized_name_fn=None,\n                         use_patches_second_moment_op=None):\n  \"\"\"Sets various global constants used by the classes in this module.\"\"\"\n  global INIT_COVARIANCES_AT_ZERO\n  global ZERO_DEBIAS\n  global INIT_INVERSES_AT_ZERO\n  global EIGENVALUE_DECOMPOSITION_THRESHOLD\n  global EIGENVALUE_CLIPPING_THRESHOLD\n  global ASSUME_ZERO_MEAN_ACTIVATIONS\n  global SUBTRACT_MEAN_CONTRIB_FROM_COV\n\n  global _SUB_SAMPLE_INPUTS\n  global _INPUTS_TO_EXTRACT_PATCHES_FACTOR\n  global _SUB_SAMPLE_PATCHES\n  global _MAX_NUM_PATCHES\n  global _MAX_NUM_PATCHES_PER_DIMENSION\n  global _GET_SANITIZED_NAME_FN\n  global TOWER_STRATEGY\n  global _USE_PATCHES_SECOND_MOMENT_OP\n\n  if init_covariances_at_zero is not None:\n    INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero\n  if zero_debias is 
not None:\n    ZERO_DEBIAS = zero_debias\n  if init_inverses_at_zero is not None:\n    INIT_INVERSES_AT_ZERO = init_inverses_at_zero\n  if eigenvalue_decomposition_threshold is not None:\n    EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold\n  if eigenvalue_clipping_threshold is not None:\n    EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold\n  if assume_zero_mean_activations is not None:\n    ASSUME_ZERO_MEAN_ACTIVATIONS = assume_zero_mean_activations\n  if subtract_mean_contrib_from_cov is not None:\n    SUBTRACT_MEAN_CONTRIB_FROM_COV = subtract_mean_contrib_from_cov\n  if sub_sample_inputs is not None:\n    _SUB_SAMPLE_INPUTS = sub_sample_inputs\n  if inputs_to_extract_patches_factor is not None:\n    _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor\n  if sub_sample_patches is not None:\n    _SUB_SAMPLE_PATCHES = sub_sample_patches\n  if max_num_patches is not None:\n    _MAX_NUM_PATCHES = max_num_patches\n  if max_num_patches_per_dimension is not None:\n    _MAX_NUM_PATCHES_PER_DIMENSION = max_num_patches_per_dimension\n  if tower_strategy is not None:\n    TOWER_STRATEGY = tower_strategy\n  if get_sanitized_name_fn is not None:\n    _GET_SANITIZED_NAME_FN = get_sanitized_name_fn\n  if use_patches_second_moment_op is not None:\n    _USE_PATCHES_SECOND_MOMENT_OP = use_patches_second_moment_op\n\n\nif INIT_INVERSES_AT_ZERO:\n  inverse_initializer = tf.zeros_initializer\nelse:\n  inverse_initializer = tf.initializers.identity\n\n\nif INIT_COVARIANCES_AT_ZERO:\n  covariance_initializer = tf.zeros_initializer\nelse:\n  covariance_initializer = tf.initializers.identity\n\n\nif INIT_COVARIANCES_AT_ZERO:\n  diagonal_covariance_initializer = tf.zeros_initializer\nelse:\n  diagonal_covariance_initializer = tf.ones_initializer\n\n\n@contextlib.contextmanager\ndef maybe_place_on_device(device):\n  if device is not None and len(device) and TOWER_STRATEGY == \"separate\":\n    with tf.device(device):\n      yield\n  
else:\n    yield\n\n\ndef compute_cov(tensor, tensor_right=None, normalizer=None):\n  \"\"\"Compute the empirical second moment of the rows of a 2D Tensor.\n\n  This function is meant to be applied to random matrices for which the true row\n  mean is zero, so that the true second moment equals the true covariance.\n\n  Args:\n    tensor: A 2D Tensor.\n    tensor_right: An optional 2D Tensor. If provided, this function computes\n      the matrix product tensor^T * tensor_right instead of tensor^T * tensor.\n    normalizer: optional scalar for the estimator (by default, the normalizer is\n        the number of rows of tensor).\n\n  Returns:\n    A square 2D Tensor with as many rows/cols as the number of input columns.\n  \"\"\"\n  if normalizer is None:\n    normalizer = utils.get_shape(tensor)[0]\n  if tensor_right is None:\n    cov = (\n        tf.matmul(tensor, tensor, transpose_a=True) / tf.cast(\n            normalizer, tensor.dtype))\n    return (cov + tf.transpose(cov)) / tf.cast(2.0, cov.dtype)\n  else:\n    return (tf.matmul(tensor, tensor_right, transpose_a=True) /\n            tf.cast(normalizer, tensor.dtype))\n\n\ndef append_homog(tensor, homog_value=None):\n  \"\"\"Appends a homogeneous coordinate to the last dimension of a Tensor.\n\n  Args:\n    tensor: A Tensor.\n    homog_value: Value to append as homogeneous coordinate to the last dimension\n      of `tensor`.  If None 1.0 is used. (Default: None)\n\n  Returns:\n    A Tensor identical to the input but one larger in the last dimension.  
The\n    new entries are filled with ones.\n  \"\"\"\n  shape = tensor.shape.as_list()\n  rank = len(shape)\n  if any(elt is None for elt in shape):\n    shape = tf.concat([tf.shape(tensor)[:-1], [1]], axis=0)\n  else:\n    shape[-1] = 1\n  if homog_value is not None:\n    appendage = homog_value * tf.ones(shape, dtype=tensor.dtype)\n  else:\n    appendage = tf.ones(shape, dtype=tensor.dtype)\n  return tf.concat([tensor, appendage], axis=-1)\n\n\ndef scope_string_from_params(params):\n  \"\"\"Builds a variable scope string name from the given parameters.\n\n  Supported parameters are:\n    * tensors\n    * booleans\n    * ints\n    * strings\n    * depth-1 tuples/lists of ints\n    * any depth tuples/lists of tensors\n  Other parameter types will throw an error.\n\n  Args:\n    params: A parameter or list of parameters.\n\n  Returns:\n    A string to use for the variable scope.\n\n  Raises:\n    ValueError: if params includes an unsupported type.\n  \"\"\"\n  params = params if isinstance(params, (tuple, list)) else (params,)\n\n  name_parts = []\n  for param in params:\n    if param is None:\n      name_parts.append(\"None\")\n    elif isinstance(param, (tuple, list)):\n      if all([isinstance(p, int) for p in param]):\n        name_parts.append(\"-\".join([str(p) for p in param]))\n      else:\n        name_parts.append(scope_string_from_name(param))\n    elif isinstance(param, (six.string_types, int, bool)):\n      name_parts.append(str(param))\n    elif isinstance(param, (tf.Tensor, tf.Variable)):\n      name_parts.append(scope_string_from_name(param))\n    elif isinstance(param, utils.PartitionedTensor):\n      name_parts.append(scope_string_from_name(param.tensors))\n    else:\n      raise ValueError(\"Encountered an unsupported param {} of type {}\".format(\n          param, type(param)))\n  return \"_\".join(name_parts)\n\n\ndef scope_string_from_name(tensor):\n  if isinstance(tensor, (tuple, list)):\n    return \"__\".join([scope_string_from_name(t) for t 
in tensor])\n  # \"gradients/add_4_grad/Reshape:0/replica_0\" ->\n  # \"gradients_add_4_grad_Reshape_0_replica_0\"\n  tensor_name = tensor.name.replace(\"/\", \"_\").replace(\":\", \"_\")\n  return _GET_SANITIZED_NAME_FN(tensor_name)\n\n\ndef scalar_or_tensor_to_string(val):\n  return repr(val) if np.isscalar(val) else scope_string_from_name(val)\n\n\ndef list_to_string(lst):\n  return \"_\".join(val if isinstance(val, six.string_types)\n                  else scalar_or_tensor_to_string(val) for val in lst)\n\n\ndef graph_func_to_id(func):\n  \"\"\"Returns a hashable object that represents func's computation.\"\"\"\n  # TODO(b/74201126): replace with Topohash of func's output\n  return func.func_id\n\n\ndef graph_func_to_string(func):\n  # TODO(b/74201126): replace with Topohash of func's output\n  return list_to_string(func.func_id)\n\n\ndef _subsample_patches(patches, name=None):\n  \"\"\"Subsample a patches matrix.\n\n  Subsample an array of image patches. The number of patches subsampled will be\n  calculated based on the following formula:\n\n  num_patches = min(_MAX_NUM_PATCHES,\n                    ceil(_MAX_NUM_PATCHES_PER_DIMENSION*dimension))\n\n  Args:\n    patches: Tensor, of shape `[total_patches, dimension]`.\n    name: `string`, Default (None)\n\n  Returns:\n    A tensor of shape `[num_patches, dimension]`.\n\n  Raises:\n    ValueError: If patches is not matrix-shaped.\n    ValueError: If total_patches cannot be inferred.\n\n  \"\"\"\n  with tf.name_scope(name, \"subsample\", [patches]):\n    patches = tf.convert_to_tensor(patches)\n    if len(patches.shape) != 2:\n      raise ValueError(\"Input param patches must be a matrix.\")\n\n    total_patches = patches.shape.as_list()[0]\n    dimension = patches.shape.as_list()[1]\n    num_patches = min(_MAX_NUM_PATCHES,\n                      int(math.ceil(_MAX_NUM_PATCHES_PER_DIMENSION*dimension)))\n\n    if total_patches is None:\n      total_patches = utils.get_shape(patches)[0]\n\n      should_subsample 
= tf.less(num_patches, total_patches)\n      return tf.cond(should_subsample,\n                     lambda: _random_tensor_gather(patches, num_patches, name),\n                     lambda: patches)\n    else:\n      if num_patches < total_patches:\n        return _random_tensor_gather(patches, num_patches, name)\n      else:\n        return patches\n\n\ndef _random_tensor_gather(array, num_ind, name=None):\n  \"\"\"Samples random indices of an array (along the first dimension).\n\n  Args:\n    array: Tensor of shape `[batch_size, ...]`.\n    num_ind: int. Number of indices to sample.\n    name: `string`. (Default: None)\n\n  Returns:\n    A tensor of shape `[num_ind, ...]`.\n  \"\"\"\n  with tf.name_scope(name, \"random_gather\", [array]):\n    array = tf.convert_to_tensor(array)\n    total_size = array.shape.as_list()[0]\n    if total_size is None:\n      total_size = utils.get_shape(array)[0]\n    indices = tf.random_shuffle(\n        tf.range(0, total_size, dtype=utils.preferred_int_dtype()))[:num_ind]\n    return tf.gather(array, indices, axis=0)\n\n\n@six.add_metaclass(abc.ABCMeta)\nclass FisherFactor(object):\n  \"\"\"Base class for objects modeling factors of approximate Fisher blocks.\n\n  A FisherFactor represents part of an approximate Fisher Information matrix.\n  For example, one approximation to the Fisher uses the Kronecker product of two\n  FisherFactors A and B, F = kron(A, B). FisherFactors are composed with\n  FisherBlocks to construct a block-diagonal approximation to the full Fisher.\n\n  FisherFactors are backed by a single, non-trainable variable that is updated\n  by running FisherFactor.make_covariance_update_op(). 
The shape and type of\n  this variable is implementation specific.\n\n  Note that for blocks that aren't based on approximations, a 'factor' can\n  be the entire block itself, as is the case for the diagonal and full\n  representations.\n  \"\"\"\n\n  def __init__(self):\n    self._cov_tensor = None\n    self._cov = None\n    self._acc_cov = None\n\n  @abc.abstractproperty\n  def _var_scope(self):\n    \"\"\"Variable scope for this FisherFactor instance.\n\n    Returns:\n      string that unique identifies this FisherFactor instance.\n    \"\"\"\n    pass\n\n  @property\n  def name(self):\n    return self._var_scope\n\n  @abc.abstractproperty\n  def _cov_shape(self):\n    \"\"\"The shape of the variable backing this FisherFactor.\"\"\"\n    pass\n\n  @abc.abstractproperty\n  def _num_sources(self):\n    \"\"\"The number of things to sum over when updating covariance variable.\n\n    The default make_covariance_update_op function will call _compute_new_cov\n    with indices ranging from 0 to _num_sources-1. 
The typical situation is\n    where the factor wants to sum the statistics it computes over multiple\n    backpropped \"gradients\" (typically passed in via \"tensors\" or\n    \"outputs_grads\" arguments).\n    \"\"\"\n    pass\n\n  @abc.abstractproperty\n  def _num_towers(self):\n    pass\n\n  @abc.abstractproperty\n  def _dtype(self):\n    \"\"\"dtype for variable backing this factor.\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def _partial_batch_size(self, source=0, tower=0):\n    \"\"\"Returns (partial) batch size associated with given source and tower.\"\"\"\n    pass\n\n  def batch_size(self, source=0):\n    \"\"\"Returns (total) batch size associated with given source.\"\"\"\n    return sum(self._partial_batch_size(source=source, tower=tower)\n               for tower in range(self._num_towers))\n\n  def check_partial_batch_sizes(self):\n    \"\"\"Ensures partial batch sizes are equal across towers and source.\"\"\"\n\n    # While it could be okay in principle to have different batch sizes for\n    # different towers, the way the code has been written isn't compatible with\n    # this. Basically, the normalizations occur for each tower and then the\n    # results are summed across towers and divided by the number of towers.\n    # The only way this is correct is if the towers all have the same batch\n    # size.\n\n    # Should make these messages use quote characters instead of parentheses\n    # when the bug with quote character rendering in assertion messages is\n    # fixed. See b/129476712\n    msg = (\"Inconsistent (partial) batch sizes detected for factor ({}) of type\"\n           \" {}. 
This can be caused by passing Tensors with the wrong sizes to \"\n           \"the registration functions, or misspecification of arguments like \"\n           \"batch_size, num_uses, or num_timesteps.\".format(\n               self.name, utils.cls_name(self)))\n\n    partial_batch_size = self._partial_batch_size()\n\n    if self._num_sources > 1 or self._num_towers > 1:\n      if isinstance(partial_batch_size, int):\n        checks = tuple(\n            partial_batch_size == self._partial_batch_size(source=source,\n                                                           tower=tower)\n            for source, tower in zip(range(self._num_sources),\n                                     range(self._num_towers)))\n        if not all(checks):\n          raise ValueError(msg)\n\n        return tf.no_op()\n\n      else:\n        asserts = tuple(\n            tf.assert_equal(partial_batch_size,\n                            self._partial_batch_size(source=source,\n                                                     tower=tower),\n                            message=msg)\n            for source, tower in zip(range(self._num_sources),\n                                     range(self._num_towers)))\n        return tf.group(asserts)\n\n    return tf.no_op()\n\n  @property\n  def _cov_initializer(self):\n    \"\"\"Function for initializing covariance variable.\"\"\"\n    return covariance_initializer\n\n  def instantiate_cov_variables(self):\n    \"\"\"Makes the internal cov variable(s).\"\"\"\n    assert self._cov is None\n    with tf.variable_scope(self._var_scope):\n      self._cov = utils.MovingAverageVariable(\n          name=\"cov\",\n          shape=self._cov_shape,\n          dtype=self._dtype,\n          initializer=self._cov_initializer,\n          normalize_value=ZERO_DEBIAS)\n\n  @abc.abstractmethod\n  def _compute_new_cov(self, source, tower):\n    \"\"\"Computes minibatch-estimated covariance for a single source.\n\n    Args:\n      source: int in [0, 
self._num_sources). Which source to use when computing\n        the cov update.\n      tower: int in [0, self._num_towers). Which tower to use when computing\n        the cov update.\n\n    Returns:\n      Tensor of same shape as self.cov.\n    \"\"\"\n    pass\n\n  def _compute_total_new_cov(self):\n    \"\"\"Computes covariance by summing across (source, towers).\"\"\"\n    new_cov_contribs = []\n    for source in range(self._num_sources):\n      for tower in range(self._num_towers):\n        with maybe_place_on_device(self._get_data_device(tower)):\n          new_cov_contribs.append(self._compute_new_cov(source, tower))\n\n    new_cov = tf.add_n(new_cov_contribs) / float(self._num_towers)\n\n    # Compute average of 'new_cov' across all replicas. On a replica, each\n    # instance of 'new_cov' will be based on a different minibatch. This ensures\n    # that by the time variable assignment happens, all replicas have the same\n    # value.\n    #\n    # Other implementations of make_covariance_update_op() that accumulate\n    # statistics in other variables should mimic this behavior.\n    #\n    # NOTE: communicating this matrix at every iteration is wasteful in the\n    # sense that we might only need fresh copies when we do the inversions.\n    # (Although be careful about factors [e.g. diagonal] or ops\n    # [e.g. multiply()] that directly use the cov vars instead of the inv vars!)\n    new_cov = utils.all_average(new_cov)\n\n    return new_cov\n\n  def make_covariance_update_op(self, ema_decay, ema_weight):\n    \"\"\"Constructs and returns the covariance update Op.\n\n    Args:\n      ema_decay: float or Tensor. The exponential moving average decay.\n      ema_weight: float or Tensor. 
The weight to put on the newly computed values.\n        This is typically 1.0 - ema_decay.\n\n    Returns:\n      The op which updates the cov variable (via acc_cov).\n    \"\"\"\n    cov_tensor = self._compute_total_new_cov()\n    self._cov_tensor = cov_tensor  # This is used for non-standard applications\n                                   # and debugging I think.\n\n    return self._cov.add_to_average(cov_tensor, decay=ema_decay,\n                                    weight=ema_weight)\n\n  @abc.abstractmethod\n  def _get_data_device(self, tower):\n    pass\n\n  @abc.abstractmethod\n  def instantiate_inv_variables(self):\n    \"\"\"Makes the internal \"inverse\" variable(s).\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def make_inverse_update_ops(self):\n    \"\"\"Create and return update ops corresponding to registered computations.\"\"\"\n    pass\n\n  @property\n  def cov(self):\n    return self._cov.value\n\n  def get_cov_vars(self):\n    return [self.cov]\n\n  def get_inv_vars(self):\n    return []\n\n  @abc.abstractmethod\n  def get_cov_as_linear_operator(self):\n    \"\"\"Returns `LinearOperator` instance which wraps the cov matrix.\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def register_matpower(self, exp, damping_func):\n    pass\n\n  @abc.abstractmethod\n  def register_cholesky(self, damping_func):\n    pass\n\n  @abc.abstractmethod\n  def register_cholesky_inverse(self, damping_func):\n    pass\n\n  @abc.abstractmethod\n  def get_matpower(self, exp, damping_func):\n    pass\n\n  @abc.abstractmethod\n  def get_cholesky(self, damping_func):\n    pass\n\n  @abc.abstractmethod\n  def get_cholesky_inverse(self, damping_func):\n    pass\n\n\nclass DenseSquareMatrixFactor(FisherFactor):\n  \"\"\"Base class for FisherFactors that are stored as dense square matrices.\n\n  This class explicitly calculates and stores inverses of their `cov` matrices,\n  which must be square dense matrices.\n\n  Subclasses must implement the _compute_new_cov method, and the 
_var_scope and\n  _cov_shape properties.\n  \"\"\"\n\n  # TODO(b/69108481): This class (and its subclasses) should be refactored to\n  # serve the matrix quantities it computes as both (potentially stale)\n  # variables, updated by the inverse update ops, and fresh values stored in\n  # tensors that recomputed once every session.run() call.  Currently matpower\n  # and damp_inverse have the former behavior, while eigendecomposition has\n  # the latter.\n\n  def __init__(self):\n    self._matpower_by_exp_and_damping = OrderedDict()  # { (float, hashable): variable }\n    self._matpower_registrations = set()  # { (float, hashable) }\n    self._eigendecomp = None\n    self._damping_funcs_by_id = OrderedDict()  # {hashable: lambda}\n\n    self._cholesky_registrations = set()  # { hashable }\n    self._cholesky_inverse_registrations = set()  # { hashable }\n\n    self._cholesky_by_damping = OrderedDict()  # { hashable: variable }\n    self._cholesky_inverse_by_damping = OrderedDict()  # { hashable: variable }\n\n    super(DenseSquareMatrixFactor, self).__init__()\n\n  def get_cov_as_linear_operator(self):\n    \"\"\"Returns `LinearOperator` instance which wraps the cov matrix.\"\"\"\n    assert self.cov.shape.ndims == 2\n    return lo.LinearOperatorFullMatrix(self.cov,\n                                       is_self_adjoint=True,\n                                       is_square=True)\n\n  def _register_damping(self, damping_func):\n    damping_id = graph_func_to_id(damping_func)\n    if damping_id not in self._damping_funcs_by_id:\n      self._damping_funcs_by_id[damping_id] = damping_func\n    return damping_id\n\n  def register_inverse(self, damping_func):\n    # Just for backwards compatibility of some old code and tests\n    self.register_matpower(-1, damping_func)\n\n  def register_matpower(self, exp, damping_func):\n    \"\"\"Registers a matrix power to be maintained and served on demand.\n\n    This creates a variable and signals make_inverse_update_ops to make 
the\n    corresponding update op.  The variable can be read via the method\n    get_matpower.\n\n    Args:\n      exp: float.  The exponent to use in the matrix power.\n      damping_func: A function that computes a 0-D Tensor or a float which will\n        be the damping value used.  i.e. damping = damping_func().\n    \"\"\"\n    if exp == 1.0:\n      return\n\n    damping_id = self._register_damping(damping_func)\n\n    if (exp, damping_id) not in self._matpower_registrations:\n      self._matpower_registrations.add((exp, damping_id))\n\n  def register_cholesky(self, damping_func):\n    \"\"\"Registers a Cholesky factor to be maintained and served on demand.\n\n    This creates a variable and signals make_inverse_update_ops to make the\n    corresponding update op.  The variable can be read via the method\n    get_cholesky.\n\n    Args:\n      damping_func: A function that computes a 0-D Tensor or a float which will\n        be the damping value used.  i.e. damping = damping_func().\n    \"\"\"\n    damping_id = self._register_damping(damping_func)\n\n    if damping_id not in self._cholesky_registrations:\n      self._cholesky_registrations.add(damping_id)\n\n  def register_cholesky_inverse(self, damping_func):\n    \"\"\"Registers an inverse Cholesky factor to be maintained/served on demand.\n\n    This creates a variable and signals make_inverse_update_ops to make the\n    corresponding update op.  The variable can be read via the method\n    get_cholesky_inverse.\n\n    Args:\n      damping_func: A function that computes a 0-D Tensor or a float which will\n        be the damping value used.  i.e. 
damping = damping_func().\n    \"\"\"\n    damping_id = self._register_damping(damping_func)\n\n    if damping_id not in self._cholesky_inverse_registrations:\n      self._cholesky_inverse_registrations.add(damping_id)\n\n  def get_inv_vars(self):\n    inv_vars = []\n    inv_vars.extend(self._matpower_by_exp_and_damping.values())\n    inv_vars.extend(self._cholesky_by_damping.values())\n    inv_vars.extend(self._cholesky_inverse_by_damping.values())\n    return inv_vars\n\n  def instantiate_inv_variables(self):\n    \"\"\"Makes the internal \"inverse\" variable(s).\"\"\"\n\n    for (exp, damping_id) in self._matpower_registrations:\n      exp_string = scalar_or_tensor_to_string(exp)\n      damping_func = self._damping_funcs_by_id[damping_id]\n      damping_string = graph_func_to_string(damping_func)\n      with tf.variable_scope(self._var_scope):\n        matpower = tf.get_variable(\n            \"matpower_exp{}_damp{}\".format(exp_string, damping_string),\n            initializer=inverse_initializer,\n            shape=self._cov_shape,\n            trainable=False,\n            dtype=self._dtype,\n            use_resource=True)\n      assert (exp, damping_id) not in self._matpower_by_exp_and_damping\n      self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower\n\n    for damping_id in self._cholesky_registrations:\n      damping_func = self._damping_funcs_by_id[damping_id]\n      damping_string = graph_func_to_string(damping_func)\n      with tf.variable_scope(self._var_scope):\n        chol = tf.get_variable(\n            \"cholesky_damp{}\".format(damping_string),\n            initializer=inverse_initializer,\n            shape=self._cov_shape,\n            trainable=False,\n            dtype=self._dtype,\n            use_resource=True)\n      assert damping_id not in self._cholesky_by_damping\n      self._cholesky_by_damping[damping_id] = chol\n\n    for damping_id in self._cholesky_inverse_registrations:\n      damping_func = 
self._damping_funcs_by_id[damping_id]\n      damping_string = graph_func_to_string(damping_func)\n      with tf.variable_scope(self._var_scope):\n        cholinv = tf.get_variable(\n            \"cholesky_inverse_damp{}\".format(damping_string),\n            initializer=inverse_initializer,\n            shape=self._cov_shape,\n            trainable=False,\n            dtype=self._dtype,\n            use_resource=True)\n      assert damping_id not in self._cholesky_inverse_by_damping\n      self._cholesky_inverse_by_damping[damping_id] = cholinv\n\n  def make_inverse_update_ops(self):\n    \"\"\"Create and return update ops corresponding to registered computations.\"\"\"\n    ops = []\n\n    num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping\n                       if exp == -1)\n\n    num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses\n\n    other_matrix_power_registered = num_other_matpower >= 1\n\n    use_eig = (\n        self._eigendecomp or other_matrix_power_registered or\n        num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD)\n\n    # We precompute these so we don't need to evaluate them multiple times (for\n    # each matrix power that uses them)\n    damping_value_by_id = {damping_id: tf.cast(\n        self._damping_funcs_by_id[damping_id](), self._dtype)\n                           for damping_id in self._damping_funcs_by_id}\n\n    if use_eig:\n      eigenvalues, eigenvectors = self.get_eigendecomp()  # pylint: disable=unpacking-non-sequence\n\n      for (exp, damping_id), matpower in (\n          self._matpower_by_exp_and_damping.items()):\n        damping = damping_value_by_id[damping_id]\n        ops.append(\n            utils.smart_assign(\n                matpower,\n                tf.matmul(eigenvectors * (eigenvalues + damping)**exp,\n                          tf.transpose(eigenvectors))))\n      # These ops share computation and should be run on a single device.\n      ops = [tf.group(*ops)]\n    
else:\n      for (exp, damping_id), matpower in (\n          self._matpower_by_exp_and_damping.items()):\n        assert exp == -1\n        damping = damping_value_by_id[damping_id]\n        ops.append(\n            utils.smart_assign(matpower, utils.posdef_inv(self.cov, damping)))\n\n    # TODO(b/77902055): If inverses are being computed with Cholesky's\n    # we can share the work. Instead this code currently just computes the\n    # Cholesky a second time. It does at least share work between requests for\n    # Cholesky's and Cholesky inverses with the same damping id.\n    for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items():\n      cholesky_ops = []\n\n      damping = damping_value_by_id[damping_id]\n      cholesky_value = utils.cholesky(self.cov, damping)\n\n      if damping_id in self._cholesky_by_damping:\n        cholesky = self._cholesky_by_damping[damping_id]\n        cholesky_ops.append(utils.smart_assign(cholesky, cholesky_value))\n\n      identity = tf.eye(\n          cholesky_value.shape.as_list()[0], dtype=cholesky_value.dtype)\n      cholesky_inv_value = tf.matrix_triangular_solve(cholesky_value, identity)\n      cholesky_ops.append(utils.smart_assign(cholesky_inv, cholesky_inv_value))\n\n      ops.append(tf.group(*cholesky_ops))\n\n    for damping_id, cholesky in self._cholesky_by_damping.items():\n      if damping_id not in self._cholesky_inverse_by_damping:\n        damping = damping_value_by_id[damping_id]\n        cholesky_value = utils.cholesky(self.cov, damping)\n        ops.append(utils.smart_assign(cholesky, cholesky_value))\n\n    self._eigendecomp = False\n    return ops\n\n  def get_inverse(self, damping_func):\n    # Just for backwards compatibility of some old code and tests\n    return self.get_matpower(-1, damping_func)\n\n  def get_matpower(self, exp, damping_func):\n    # Note that this function returns a variable which gets updated by the\n    # inverse ops.  
It may be stale / inconsistent with the latest value of\n    # self.cov (except when exp == 1).\n    if exp != 1:\n      damping_id = graph_func_to_id(damping_func)\n      matpower = self._matpower_by_exp_and_damping[(exp, damping_id)]\n    else:\n      cov = self.cov\n      identity = tf.eye(cov.shape.as_list()[0], dtype=cov.dtype)\n      matpower = cov + tf.cast(damping_func(), dtype=self.cov.dtype)*identity\n\n    assert matpower.shape.ndims == 2\n    return lo.LinearOperatorFullMatrix(matpower,\n                                       is_non_singular=True,\n                                       is_self_adjoint=True,\n                                       is_positive_definite=True,\n                                       is_square=True)\n\n  def get_cholesky(self, damping_func):\n    # Note that this function returns a variable which gets updated by the\n    # inverse ops.  It may be stale / inconsistent with the latest value of\n    # self.cov.\n    damping_id = graph_func_to_id(damping_func)\n    cholesky = self._cholesky_by_damping[damping_id]\n    assert cholesky.shape.ndims == 2\n    return lo.LinearOperatorFullMatrix(cholesky,\n                                       is_non_singular=True,\n                                       is_square=True)\n\n  def get_cholesky_inverse(self, damping_func):\n    # Note that this function returns a variable which gets updated by the\n    # inverse ops.  
It may be stale / inconsistent with the latest value of\n    # self.cov.\n    damping_id = graph_func_to_id(damping_func)\n    cholesky_inv = self._cholesky_inverse_by_damping[damping_id]\n    assert cholesky_inv.shape.ndims == 2\n    return lo.LinearOperatorFullMatrix(cholesky_inv,\n                                       is_non_singular=True,\n                                       is_square=True)\n\n  def get_eigendecomp(self):\n    \"\"\"Creates or retrieves eigendecomposition of self._cov.\"\"\"\n    # Unlike get_matpower this doesn't retrieve a stored variable, but instead\n    # always computes a fresh version from the current value of self.cov.\n    if not self._eigendecomp:\n      eigenvalues, eigenvectors = tf.self_adjoint_eig(self.cov)\n\n      # The matrix self._cov is positive semidefinite by construction, but the\n      # numerical eigenvalues could be negative due to numerical errors, so here\n      # we clip them to be at least FLAGS.eigenvalue_clipping_threshold\n      clipped_eigenvalues = tf.maximum(eigenvalues,\n                                       EIGENVALUE_CLIPPING_THRESHOLD)\n      self._eigendecomp = (clipped_eigenvalues, eigenvectors)\n\n    return self._eigendecomp\n\n\nclass NaiveFullFactor(DenseSquareMatrixFactor):\n  \"\"\"FisherFactor for a full matrix representation of the Fisher of a parameter.\n\n  Note that this uses the naive \"square the sum estimator\", and so is applicable\n  to any type of parameter in principle, but has very high variance.\n  \"\"\"\n\n  def __init__(self,\n               params_grads,\n               batch_size):\n    self._batch_size = batch_size\n    self._params_grads = tuple(utils.ensure_sequence(params_grad)\n                               for params_grad in params_grads)\n    super(NaiveFullFactor, self).__init__()\n\n  @property\n  def _var_scope(self):\n    return \"ff_naivefull_\" + scope_string_from_params(\n        [self._params_grads, self._batch_size])\n\n  @property\n  def _cov_shape(self):\n 
   size = sum(param_grad.shape.num_elements()\n               for param_grad in self._params_grads[0])\n    return (size, size)\n\n  @property\n  def _num_sources(self):\n    return len(self._params_grads)\n\n  @property\n  def _num_towers(self):\n    return 1\n\n  @property\n  def _dtype(self):\n    return self._params_grads[0][0].dtype\n\n  def _partial_batch_size(self, source=0, tower=0):\n    assert source == 0 and tower == 0\n    return self._batch_size\n\n  def _compute_new_cov(self, source, tower):\n    assert tower == 0\n\n    # This will be a very basic rank 1 estimate\n    params_grads_flat = utils.tensors_to_column(self._params_grads[source])\n    return ((params_grads_flat * tf.transpose(params_grads_flat)) / tf.cast(\n        self._batch_size, params_grads_flat.dtype))\n\n  def _get_data_device(self, tower):\n    return None\n\n\n@six.add_metaclass(abc.ABCMeta)\nclass DiagonalFactor(FisherFactor):\n  \"\"\"A base class for FisherFactors that use diagonal approximations.\n\n  A DiagonalFactor's covariance variable can be of any shape, but must contain\n  exactly one entry per parameter.\n  \"\"\"\n\n  def get_cov_as_linear_operator(self):\n    \"\"\"Returns `LinearOperator` instance which wraps the cov matrix.\"\"\"\n    return lo.LinearOperatorDiag(self._matrix_diagonal,\n                                 is_self_adjoint=True,\n                                 is_square=True)\n\n  @property\n  def _cov_initializer(self):\n    return diagonal_covariance_initializer\n\n  @property\n  def _matrix_diagonal(self):\n    return tf.reshape(self.cov, [-1])\n\n  def make_inverse_update_ops(self):\n    return []\n\n  def instantiate_inv_variables(self):\n    pass\n\n  def register_matpower(self, exp, damping_func):\n    pass\n\n  def register_cholesky(self, damping_func):\n    pass\n\n  def register_cholesky_inverse(self, damping_func):\n    pass\n\n  def get_matpower(self, exp, damping_func):\n    matpower_diagonal = (self._matrix_diagonal\n                       
  + tf.cast(damping_func(), self._dtype))**exp\n    return lo.LinearOperatorDiag(matpower_diagonal,\n                                 is_non_singular=True,\n                                 is_self_adjoint=True,\n                                 is_positive_definite=True,\n                                 is_square=True)\n\n  def get_cholesky(self, damping_func):\n    return self.get_matpower(0.5, damping_func)\n\n  def get_cholesky_inverse(self, damping_func):\n    return self.get_matpower(-0.5, damping_func)\n\n\nclass NaiveDiagonalFactor(DiagonalFactor):\n  \"\"\"FisherFactor for a diagonal approximation of any type of param's Fisher.\n\n  Note that this uses the naive \"square the sum estimator\", and so is applicable\n  to any type of parameter in principle, but has very high variance.\n  \"\"\"\n\n  def __init__(self,\n               params_grads,\n               batch_size):\n    \"\"\"Initializes NaiveDiagonalFactor instance.\n\n    Args:\n      params_grads: List of tensors (or lists), with the first index\n        corresponding to source, and the second optional index corresponding\n        to the element of the parameter list.\n      batch_size: int or 0-D Tensor. 
The batch size.\n    \"\"\"\n    self._params_grads = params_grads\n    self._batch_size = batch_size\n    super(NaiveDiagonalFactor, self).__init__()\n\n  @property\n  def _var_scope(self):\n    return \"ff_naivediag_\" + scope_string_from_params(\n        [self._params_grads, self._batch_size])\n\n  @property\n  def _cov_shape(self):\n    return self._params_grads[0].shape\n\n  @property\n  def _num_sources(self):\n    return len(self._params_grads)\n\n  @property\n  def _num_towers(self):\n    return 1\n\n  @property\n  def _dtype(self):\n    return self._params_grads[0].dtype\n\n  def _partial_batch_size(self, source=0, tower=0):\n    assert source == 0 and tower == 0\n    return self._batch_size\n\n  def _compute_new_cov(self, source, tower):\n    assert tower == 0\n    return (tf.square(self._params_grads[source]) / tf.cast(\n        self._batch_size, self._params_grads[source].dtype))\n\n  def _get_data_device(self, tower):\n    return None\n\n\nclass DiagonalKroneckerFactor(DiagonalFactor):\n  \"\"\"A Kronecker FisherFactor using diagonal approximations.\n\n  This class handles both sparse and dense inputs. The covariance is estimated\n  using the diagonal covariance matrix. For a dense tensor:\n\n    Cov(inputs, inputs) = (1/batch_size) sum_{i} diag(inputs[i,:] ** 2).\n\n  For sparse inputs, one of the most common use cases is the sparse input to an\n  embedding layer. 
Given tensor = [batch_size, input_size] representing\n  indices into a [vocab_size, embedding_size] embedding matrix, the diagonal\n  covariance matrix is
If True, append '1' to each input.\n    \"\"\"\n    self._tensors = tensors\n    dtype = dtype or tf.float32\n    self._has_bias = has_bias\n    self._one_hot_depth = getattr(self._tensors[0][0], \"one_hot_depth\", None)\n    if self._one_hot_depth is None:\n      self._dense_input = True\n      self._cov_dtype = self._tensors[0][0].dtype\n    else:\n      self._dense_input = False\n      self._cov_dtype = dtype\n\n    super(DiagonalKroneckerFactor, self).__init__()\n\n  @property\n  def _var_scope(self):\n    return \"ff_diag_kron_\" + scope_string_from_params(\n        nest.flatten(self._tensors))\n\n  @property\n  def _cov_shape(self):\n    if self._dense_input:\n      size = self._tensors[0][0].shape[1] + self._has_bias\n    else:\n      size = self._one_hot_depth + self._has_bias\n    return [size]\n\n  @property\n  def _num_sources(self):\n    return len(self._tensors)\n\n  @property\n  def _num_towers(self):\n    return len(self._tensors[0])\n\n  @property\n  def _dtype(self):\n    return self._cov_dtype\n\n  def _partial_batch_size(self, source=0, tower=0):\n    return utils.get_shape(self._tensors[source][tower])[0]\n\n  def _compute_new_cov(self, source, tower):\n    return self._compute_new_cov_from_tensor(self._tensors[source][tower])\n\n  def _compute_new_cov_from_tensor(self, tensor):\n    batch_size = utils.get_shape(tensor)[0]\n\n    if self._dense_input:\n      if len(tensor.shape) != 2:\n        raise ValueError(\n            \"Dense input tensors to DiagonalKroneckerFactor must have \"\n            \"rank == 2. Found tensor with wrong rank: {}\".format(tensor))\n      new_cov = tf.square(tensor)\n    else:\n      if len(tensor.shape) != 1:\n        raise ValueError(\n            \"Sparse input tensors to DiagonalKroneckerFactor must have \"\n            \"rank == 1. 
Found tensor with wrong rank: {}\".format(tensor))\n      # Transform indices into one-hot vectors.\n      #\n      # TODO(b/72714822): There must be a faster way to construct the diagonal\n      # covariance matrix! This operation is O(batch_size * vocab_size), where\n      # it should be O(batch_size * input_size).\n      flat_input_ids = tf.reshape(tensor, [-1])\n      new_cov = tf.one_hot(flat_input_ids,\n                           self._one_hot_depth)  # [?, vocab_size]\n\n      # Take average across examples. Note that, because all entries have\n      # magnitude zero or one, there's no need to square the entries.\n      #\n      # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation\n      # within an example such as average.\n      #\n      # TODO(b/72714822): Support for partitioned embeddings.\n\n    new_cov = tf.reduce_sum(new_cov, axis=0)\n    new_cov /= tf.cast(batch_size, new_cov.dtype)\n\n    if self._has_bias:\n      new_cov = append_homog(new_cov)\n\n    return new_cov\n\n  def _get_data_device(self, tower):\n    return self._tensors[0][tower].device\n\n\nclass DiagonalMultiKF(DiagonalKroneckerFactor):\n\n  def __init__(self, tensors, num_uses, has_bias=False, dtype=None):\n    super(DiagonalMultiKF, self).__init__(\n        tensors, dtype=dtype, has_bias=has_bias)\n    self._num_uses = num_uses\n\n  def _partial_batch_size(self, source=0, tower=0):\n    # Note that some internal comptutations of \"batch_size\" done in the parent\n    # class won't actually be the proper batch size. Instead, they will be\n    # just \"the thing to normalize the statistics by\", essentially. 
This is okay\n    # as we don't mix the two things up.\n    shape = utils.get_shape(self._tensors[source][tower])\n    if self._dense_input:\n      if len(shape) == 2:\n        # the folded case\n        return shape[0] // self._num_uses\n      elif len(shape) == 3:\n        return shape[1]  # batch is the second dim\n    else:\n      if len(shape) == 1:\n        # the folded case\n        return shape[0] // self._num_uses\n      elif len(shape) == 2:\n        return shape[1]  # batch is the second dim\n\n  @property\n  def _cov_shape(self):\n    if self._dense_input:\n      shape = self._tensors[0][0].shape\n      if len(shape) == 2:\n        size = shape[1] + self._has_bias\n      elif len(shape) == 3:\n        size = shape[2] + self._has_bias\n    else:\n      size = self._one_hot_depth + self._has_bias\n    return [size]\n\n  def _compute_new_cov(self, source, tower):\n    tensor = self._tensors[source][tower]\n    if self._dense_input:\n      if len(tensor.shape) == 3:\n        tensor = tf.reshape(tensor, [-1, tensor.shape[2]])\n    else:\n      if len(tensor.shape) == 2:\n        tensor = tf.reshape(tensor, [-1])\n\n    return self._compute_new_cov_from_tensor(tensor)\n\n\nclass FullyConnectedDiagonalFactor(DiagonalFactor):\n  r\"\"\"FisherFactor for a diagonal approx of a fully-connected layer's Fisher.\n\n  Given in = [batch_size, input_size] and out_grad = [batch_size, output_size],\n  approximates the covariance as,\n\n    Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0\n\n  where the square is taken element-wise.\n  \"\"\"\n\n  def __init__(self,\n               inputs,\n               outputs_grads,\n               has_bias=False):\n    \"\"\"Instantiate FullyConnectedDiagonalFactor.\n\n    Args:\n      inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this\n        layer.  
List index is towers.\n      outputs_grads: List of Tensors, each of shape [batch_size, output_size],\n        which are the gradients of the loss with respect to the layer's\n        outputs. First index is source, second is tower.\n\n      has_bias: bool. If True, append '1' to each input.\n    \"\"\"\n    self._inputs = inputs\n    self._has_bias = has_bias\n    self._outputs_grads = outputs_grads\n    self._squared_inputs = None\n\n    super(FullyConnectedDiagonalFactor, self).__init__()\n\n  @property\n  def _var_scope(self):\n    return \"ff_diagfc_\" + scope_string_from_params(\n        tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))\n\n  @property\n  def _cov_shape(self):\n    input_size = self._inputs[0].shape[1] + self._has_bias\n    output_size = self._outputs_grads[0][0].shape[1]\n    return [input_size, output_size]\n\n  @property\n  def _num_sources(self):\n    return len(self._outputs_grads)\n\n  @property\n  def _num_towers(self):\n    return len(self._inputs)\n\n  @property\n  def _dtype(self):\n    return self._outputs_grads[0][0].dtype\n\n  def _partial_batch_size(self, source=0, tower=0):\n    return utils.get_shape(self._outputs_grads[source][tower])[0]\n\n  def make_covariance_update_op(self, ema_decay, ema_weight):\n\n    self._squared_inputs = []\n    for tower in range(self._num_towers):\n      inputs = self._inputs[tower]\n\n      with maybe_place_on_device(self._get_data_device(tower)):\n        if self._has_bias:\n          inputs = append_homog(inputs)\n        self._squared_inputs.append(tf.square(inputs))\n\n    return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op(\n        ema_decay, ema_weight)\n\n  def _compute_new_cov(self, source, tower):\n    batch_size = utils.get_shape(self._squared_inputs[tower])[0]\n\n    outputs_grad = self._outputs_grads[source][tower]\n\n    # The well-known special formula that uses the fact that the entry-wise\n    # square of an outer product is the outer-product 
of the entry-wise squares.\n    # The gradient is the outer product of the input and the output gradients,\n    # so we just square both and then take their outer-product.\n    new_cov = tf.matmul(\n        self._squared_inputs[tower], tf.square(outputs_grad), transpose_a=True)\n    new_cov /= tf.cast(batch_size, new_cov.dtype)\n    return new_cov\n\n  def _get_data_device(self, tower):\n    return self._inputs[tower].device\n\n\n@six.add_metaclass(abc.ABCMeta)\nclass ScaleAndShiftFactor(FisherFactor):\n\n  def __init__(self,\n               inputs,\n               outputs_grads,\n               broadcast_dims_scale,\n               broadcast_dims_shift=None,\n               has_shift=True,\n               approx=\"full\"):\n\n    assert approx == \"full\" or approx == \"diagonal\"\n\n    self._inputs = inputs\n    self._outputs_grads = outputs_grads\n    self._broadcast_dims_scale = broadcast_dims_scale\n    self._broadcast_dims_shift = broadcast_dims_shift\n    self._has_shift = has_shift\n    self._approx = approx\n\n    assert not has_shift or broadcast_dims_shift is not None\n\n    super(ScaleAndShiftFactor, self).__init__()\n\n  @property\n  def _var_scope(self):\n    return \"ff_scaleshift_\" + scope_string_from_params(\n        [self._inputs, self._outputs_grads, self._broadcast_dims_scale,\n         self._broadcast_dims_shift, self._has_shift, self._approx])\n\n  @property\n  def _cov_shape(self):\n    size = np.prod([\n        self._inputs[0].shape[i]\n        for i in range(1, len(self._inputs[0].shape))\n        if i not in self._broadcast_dims_scale],\n                   dtype=np.int64)\n\n    if self._has_shift:\n      size_shift = np.prod([\n          self._outputs_grads[0][0].shape[i]\n          for i in range(1, len(self._outputs_grads[0][0].shape))\n          if i not in self._broadcast_dims_shift],\n                           dtype=np.int64)\n      size += size_shift\n\n    if self._approx == \"full\":\n      return (size, size)\n    elif 
self._approx == \"diagonal\":\n      return (size,)\n\n  @property\n  def _num_sources(self):\n    return len(self._outputs_grads)\n\n  @property\n  def _num_towers(self):\n    return len(self._inputs)\n\n  @property\n  def _dtype(self):\n    return self._inputs[0].dtype\n\n  def _partial_batch_size(self, source=0, tower=0):\n    return utils.get_shape(self._outputs_grads[source][tower])[0]\n\n  def _compute_new_cov(self, source, tower):\n    # Here we implement a \"sum of squares\" estimator that uses the special\n    # structure of the scale & shift operation. In particular, we sum across\n    # all dimensions that broadcast, then square (or take outer-products), and\n    # then average across the mini-batch.\n\n    inputs = self._inputs[tower]\n    outputs_grad = self._outputs_grads[source][tower]\n    batch_size = utils.get_shape(inputs)[0]\n\n    assert len(inputs.shape) == len(outputs_grad.shape)\n    for i in range(1, len(inputs.shape)):\n      assert inputs.shape[i] <= outputs_grad.shape[i]\n\n    # The formula for the gradient of the shift param is just the element-wise\n    # product of the inputs and the output gradients, summed across the\n    # dimensions that get broadcasted.\n    scale_grads = tf.reduce_sum(inputs * outputs_grad,\n                                axis=self._broadcast_dims_scale)\n    scale_grads_flat = tf.reshape(scale_grads, [batch_size, -1])\n\n    if self._has_shift:\n      # The formula for the gradient of the shift param is just the output\n      # gradients, summed across the dimensions that get broadcasted.\n      shift_grads = tf.reduce_sum(outputs_grad,\n                                  axis=self._broadcast_dims_shift)\n      shift_grads_flat = tf.reshape(shift_grads, [batch_size, -1])\n\n      params_grads_flat = tf.concat([scale_grads_flat, shift_grads_flat],\n                                    axis=1)\n    else:\n      params_grads_flat = scale_grads_flat\n\n    if self._approx == \"full\":\n      new_cov = 
compute_cov(params_grads_flat)\n\n    elif self._approx == \"diagonal\":\n      new_cov = tf.reduce_mean(tf.square(params_grads_flat), axis=0)\n\n    return new_cov\n\n  def _get_data_device(self, tower):\n    return self._inputs[tower].device\n\n\nclass ScaleAndShiftFullFactor(ScaleAndShiftFactor, DenseSquareMatrixFactor):\n\n  def __init__(self,\n               inputs,\n               outputs_grads,\n               broadcast_dims_scale,\n               broadcast_dims_shift=None,\n               has_shift=True):\n\n    super(ScaleAndShiftFullFactor, self).__init__(\n        inputs,\n        outputs_grads,\n        broadcast_dims_scale,\n        broadcast_dims_shift=broadcast_dims_shift,\n        has_shift=has_shift,\n        approx=\"full\")\n\n\nclass ScaleAndShiftDiagonalFactor(ScaleAndShiftFactor, DiagonalFactor):\n\n  def __init__(self,\n               inputs,\n               outputs_grads,\n               broadcast_dims_scale,\n               broadcast_dims_shift=None,\n               has_shift=True):\n\n    super(ScaleAndShiftDiagonalFactor, self).__init__(\n        inputs,\n        outputs_grads,\n        broadcast_dims_scale,\n        broadcast_dims_shift=broadcast_dims_shift,\n        has_shift=has_shift,\n        approx=\"diagonal\")\n\n\nclass ConvDiagonalFactor(DiagonalFactor):\n  \"\"\"FisherFactor for a diagonal approx of a convolutional layer's Fisher.\"\"\"\n\n  def __init__(self,\n               inputs,\n               outputs_grads,\n               filter_shape,\n               strides,\n               padding,\n               data_format=None,\n               dilations=None,\n               has_bias=False,\n               patch_mask=None):\n    \"\"\"Creates a ConvDiagonalFactor object.\n\n    Args:\n      inputs: List of Tensors of shape [batch_size, height, width, in_channels].\n        Input activations to this layer.  
List index is towers.\n      outputs_grads: List of Tensors, each of shape [batch_size,\n        height, width, out_channels], which are the gradients of the loss\n        with respect to the layer's outputs.  First index is source, second\n        index is tower.\n      filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,\n        out_channels). Represents shape of kernel used in this layer.\n      strides: The stride size in this layer (1-D Tensor of length 4).\n      padding: The padding in this layer (1-D of Tensor length 4).\n      data_format: None or str. Format of conv2d inputs.\n      dilations: None or tuple of 4 ints.\n      has_bias: Python bool. If True, the layer is assumed to have a bias\n        parameter in addition to its filter parameter.\n      patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]\n        or None. If not None this is multiplied against the extracted patches\n        Tensor (broadcasting along the batch dimension) before statistics are\n        computed. 
(Default: None)\n\n    Raises:\n      ValueError: If inputs, output_grads, and filter_shape do not agree on\n        in_channels or out_channels.\n      ValueError: If strides, dilations are not length-4 lists of ints.\n      ValueError: If data_format does not put channel last.\n    \"\"\"\n    if not utils.is_data_format_channel_last(data_format):\n      raise ValueError(\"Channel must be last.\")\n    if any(input_.shape.ndims != 4 for input_ in inputs):\n      raise ValueError(\"inputs must be a list of 4-D Tensors.\")\n    if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs):\n      raise ValueError(\"inputs and filter_shape must agree on in_channels.\")\n    for i, outputs_grad in enumerate(outputs_grads):\n      if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad):\n        raise ValueError(\"outputs[%d] must be 4-D Tensor.\" % i)\n      if any(output_grad.shape.as_list()[-1] != filter_shape[-1]\n             for output_grad in outputs_grad):\n        raise ValueError(\n            \"outputs[%d] and filter_shape must agree on out_channels.\" % i)\n    if len(strides) != 4:\n      raise ValueError(\"strides must be length-4 list of ints.\")\n    if dilations is not None and len(dilations) != 4:\n      raise ValueError(\"dilations must be length-4 list of ints.\")\n\n    self._inputs = inputs\n    self._outputs_grads = outputs_grads\n    self._filter_shape = filter_shape\n    self._strides = strides\n    self._padding = padding\n    self._data_format = data_format\n    self._dilations = dilations\n    self._has_bias = has_bias\n    self._patches = None\n\n    self._patch_mask = patch_mask\n\n    super(ConvDiagonalFactor, self).__init__()\n\n  @property\n  def _var_scope(self):\n    return \"ff_convdiag_\" + scope_string_from_params(\n        tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))\n\n  @property\n  def _cov_shape(self):\n    filter_height, filter_width, in_channels, out_channels = 
self._filter_shape\n    return [\n        filter_height * filter_width * in_channels + self._has_bias,\n        out_channels\n    ]\n\n  @property\n  def _num_sources(self):\n    return len(self._outputs_grads)\n\n  @property\n  def _num_towers(self):\n    return len(self._inputs)\n\n  @property\n  def _dtype(self):\n    return self._inputs[0].dtype\n\n  def _partial_batch_size(self, source=0, tower=0):\n    return utils.get_shape(self._outputs_grads[source][tower])[0]\n\n  def make_covariance_update_op(self, ema_decay, ema_weight):\n    filter_height, filter_width, _, _ = self._filter_shape\n\n    # TODO(b/64144716): there is potential here for a big savings in terms\n    # of memory use.\n    if self._dilations is None:\n      rates = (1, 1, 1, 1)\n    else:\n      rates = tuple(self._dilations)\n\n    self._patches = []\n    for tower in range(self._num_towers):\n      with maybe_place_on_device(self._get_data_device(tower)):\n        patches = tf.extract_image_patches(\n            self._inputs[tower],\n            ksizes=[1, filter_height, filter_width, 1],\n            strides=self._strides,\n            rates=rates,\n            padding=self._padding)\n\n        if self._patch_mask is not None:\n          assert self._patch_mask.shape == self._filter_shape[0:-1]\n          # This should work as intended due to broadcasting.\n          patches *= self._patch_mask\n\n        if self._has_bias:\n          patches = append_homog(patches)\n\n        self._patches.append(patches)\n\n    return super(ConvDiagonalFactor, self).make_covariance_update_op(\n        ema_decay, ema_weight)\n\n  def _compute_new_cov(self, source, tower):\n    patches = self._patches[tower]\n    batch_size = utils.get_shape(patches)[0]\n\n    outputs_grad = self._outputs_grads[source][tower]\n\n    new_cov = self._convdiag_sum_of_squares(patches, outputs_grad)\n    new_cov /= tf.cast(batch_size, new_cov.dtype)\n\n    return new_cov\n\n  def _convdiag_sum_of_squares(self, patches, 
outputs_grad):\n    # This computes the sum of the squares of the per-training-case \"gradients\".\n    # It does this simply by computing a giant tensor containing all of these,\n    # doing an entry-wise square, and them summing along the batch dimension.\n    case_wise_gradients = tf.einsum(\"bijk,bijl->bkl\", patches, outputs_grad)\n    return tf.reduce_sum(tf.square(case_wise_gradients), axis=0)\n\n  def _get_data_device(self, tower):\n    return self._inputs[tower].device\n\n\nclass FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):\n  \"\"\"Kronecker factor for the input or output side of a fully-connected layer.\n  \"\"\"\n\n  def __init__(self,\n               tensors,\n               has_bias=False):\n    \"\"\"Instantiate FullyConnectedKroneckerFactor.\n\n    Args:\n      tensors: List of list of Tensors, each of shape [batch_size, n]. The\n        Tensors are typically either a layer's inputs or its output's gradients.\n        The first list index is source, the second is tower.\n      has_bias: bool. If True, append '1' to each row.\n    \"\"\"\n    # The tensor argument is either a tensor of input activations or a tensor of\n    # output pre-activation gradients.\n    self._has_bias = has_bias\n    self._tensors = tensors\n\n    self._one_hot_depth = getattr(self._tensors[0][0], \"one_hot_depth\", None)\n    if self._one_hot_depth is not None:\n      raise ValueError(\"Dense factors currently don't support 1-hot sparse \"\n                       \"data. 
Note that for input factors with such data, \"\n                       \"a diagonal approximation is exact (but the same is \"\n                       \"NOT true for output factors).\")\n\n    super(FullyConnectedKroneckerFactor, self).__init__()\n\n  @property\n  def _var_scope(self):\n    return \"ff_fckron_\" + scope_string_from_params(\n        tuple(nest.flatten(self._tensors)) + (self._has_bias,))\n\n  @property\n  def _cov_shape(self):\n    size = self._tensors[0][0].shape[1] + self._has_bias\n    return [size, size]\n\n  @property\n  def _num_sources(self):\n    return len(self._tensors)\n\n  @property\n  def _num_towers(self):\n    return len(self._tensors[0])\n\n  @property\n  def _dtype(self):\n    return self._tensors[0][0].dtype\n\n  def _partial_batch_size(self, source=0, tower=0):\n    return utils.get_shape(self._tensors[source][tower])[0]\n\n  def _compute_new_cov(self, source, tower):\n    tensor = self._tensors[source][tower]\n    if self._has_bias:\n      tensor = append_homog(tensor)\n    return compute_cov(tensor)\n\n  def _get_data_device(self, tower):\n    return self._tensors[0][tower].device\n\n\nclass ConvInputKroneckerFactor(DenseSquareMatrixFactor):\n  r\"\"\"Kronecker factor for the input side of a convolutional layer.\n\n  Estimates E[ a a^T ] where a is the inputs to a convolutional layer given\n  example x. Expectation is taken over all examples and locations.\n\n  Note that this is related to Omega in https://arxiv.org/abs/1602.01407 except\n  that here we normalize by the number of locations (k). 
By setting the\n  renormalization coefficient (\"_renorm_coeff\") in the block class to k we\n  get the same overall block approximation from the paper.\n  \"\"\"\n\n  def __init__(self,\n               inputs,\n               filter_shape,\n               padding,\n               strides=None,\n               dilation_rate=None,\n               data_format=None,\n               extract_patches_fn=None,\n               has_bias=False,\n               sub_sample_inputs=None,\n               sub_sample_patches=None,\n               patch_mask=None):\n    \"\"\"Initializes ConvInputKroneckerFactor.\n\n    Args:\n      inputs: List of Tensors of shape [batch_size, ..spatial_input_size..,\n        in_channels]. Inputs to layer. List index is tower.\n      filter_shape: List of ints. Contains [..spatial_filter_size..,\n        in_channels, out_channels]. Shape of convolution kernel.\n      padding: str. Padding method for layer. \"SAME\" or \"VALID\".\n      strides: List of ints or None. Contains [..spatial_filter_strides..] if\n        'extract_patches_fn' is compatible with tf.nn.convolution(), else\n        [1, ..spatial_filter_strides, 1].\n      dilation_rate: List of ints or None. Rate for dilation along each spatial\n        dimension if 'extract_patches_fn' is compatible with\n        tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].\n      data_format: str or None. Format of input data.\n      extract_patches_fn: str or None. Name of function that extracts image\n        patches. One of \"extract_convolution_patches\", \"extract_image_patches\",\n        \"extract_pointwise_conv2d_patches\".\n      has_bias: bool. If True, append 1 to in_channel.\n      sub_sample_inputs: `bool`. If True, then subsample the inputs from which\n        the image patches are extracted. (Default: None)\n      sub_sample_patches: `bool`, If `True` then subsample the extracted\n        patches. 
(Default: None)\n      patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]\n        or None. If not None this is multiplied against the extracted patches\n        Tensor (broadcasting along the batch dimension) before statistics are\n        computed. (Default: None)\n    \"\"\"\n    self._inputs = inputs\n    self._filter_shape = filter_shape\n    self._strides = strides\n    self._padding = padding\n    self._dilation_rate = dilation_rate\n    self._data_format = data_format\n    self._extract_patches_fn = extract_patches_fn\n    self._has_bias = has_bias\n\n    if sub_sample_inputs is None:\n      self._sub_sample_inputs = _SUB_SAMPLE_INPUTS\n    else:\n      self._sub_sample_inputs = sub_sample_inputs\n\n    if sub_sample_patches is None:\n      self._sub_sample_patches = _SUB_SAMPLE_PATCHES\n    else:\n      self._sub_sample_patches = sub_sample_patches\n\n    self._patch_mask = patch_mask\n\n    super(ConvInputKroneckerFactor, self).__init__()\n\n  @property\n  def _var_scope(self):\n    return \"ff_convinkron_\" + scope_string_from_params(\n        tuple(self._inputs) +\n        tuple((self._filter_shape, self._strides, self._padding,\n               self._dilation_rate, self._data_format, self._has_bias,\n               self._patch_mask)))\n\n  @property\n  def _cov_shape(self):\n    spatial_filter_shape = self._filter_shape[0:-2]\n    in_channels = self._filter_shape[-2]\n    size = np.prod(spatial_filter_shape) * in_channels + self._has_bias\n    return [size, size]\n\n  @property\n  def _num_sources(self):\n    return 1\n\n  @property\n  def _num_towers(self):\n    return len(self._inputs)\n\n  @property\n  def _dtype(self):\n    return self._inputs[0].dtype\n\n  def _partial_batch_size(self, source=0, tower=0):\n    assert source == 0\n    return utils.get_shape(self._inputs[tower])[0]\n\n  def _compute_new_cov(self, source, tower):\n    assert source == 0\n\n    inputs = self._inputs[tower]\n    if self._sub_sample_inputs:\n\n      
batch_size = inputs.shape.as_list()[0]\n      if batch_size is None:\n        # dynamic case:\n        batch_size = utils.get_shape(inputs)[0]\n        # computes: int(math.ceil(batch_size\n        #                               * _INPUTS_TO_EXTRACT_PATCHES_FACTOR))\n        new_size = tf.cast(\n            tf.ceil(tf.multiply(tf.cast(batch_size, dtype=tf.float32),\n                                _INPUTS_TO_EXTRACT_PATCHES_FACTOR)),\n            dtype=utils.preferred_int_dtype())\n      else:\n        # static case:\n        new_size = int(math.ceil(batch_size\n                                 * _INPUTS_TO_EXTRACT_PATCHES_FACTOR))\n\n      inputs = _random_tensor_gather(inputs, new_size)\n\n    # TODO(b/64144716): there is potential here for a big savings in terms of\n    # memory use.\n    if _USE_PATCHES_SECOND_MOMENT_OP:\n      raise NotImplementedError  # patches op is not available outside of Google,\n                                 # sorry! You'll need to turn it off to proceed.\n    else:\n      if self._extract_patches_fn in [None, \"extract_convolution_patches\"]:\n        patches = utils.extract_convolution_patches(\n            inputs,\n            self._filter_shape,\n            padding=self._padding,\n            strides=self._strides,\n            dilation_rate=self._dilation_rate,\n            data_format=self._data_format)\n\n      elif self._extract_patches_fn == \"extract_image_patches\":\n        assert inputs.shape.ndims == 4\n        assert len(self._filter_shape) == 4\n        assert len(self._strides) == 4, self._strides\n        if self._dilation_rate is None:\n          rates = [1, 1, 1, 1]\n        else:\n          rates = self._dilation_rate\n          assert len(rates) == 4\n          assert rates[0] == rates[-1] == 1\n        patches = tf.extract_image_patches(\n            inputs,\n            ksizes=[1] + list(self._filter_shape[0:-2]) + [1],\n            strides=self._strides,\n            rates=rates,\n            
padding=self._padding)\n\n      elif self._extract_patches_fn == \"extract_pointwise_conv2d_patches\":\n        assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]\n        assert self._filter_shape[0] == self._filter_shape[1] == 1\n        patches = utils.extract_pointwise_conv2d_patches(\n            inputs, self._filter_shape, data_format=None)\n\n      else:\n        raise NotImplementedError(self._extract_patches_fn)\n\n      if self._patch_mask is not None:\n        assert self._patch_mask.shape == self._filter_shape[0:-1]\n        # This should work as intended due to broadcasting.\n        patches *= tf.reshape(self._patch_mask, [-1])\n\n      flatten_size = np.prod(self._filter_shape[0:-1])\n      # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde\n      # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),\n      # where M = minibatch size, |T| = number of spatial locations,\n      # |Delta| = number of spatial offsets, and J = number of input maps\n      # for convolutional layer l.\n      patches_flat = tf.reshape(patches, [-1, flatten_size])\n      # We append a homogenous coordinate to patches_flat if the layer has\n      # bias parameters. This gives us [[A_l]]_H from the paper.\n      if self._sub_sample_patches:\n        patches_flat = _subsample_patches(patches_flat)\n\n      if self._has_bias:\n        patches_flat = append_homog(patches_flat)\n      # We call compute_cov without passing in a normalizer. compute_cov uses\n      # the first dimension of patches_flat i.e. M|T| as the normalizer by\n      # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with\n      # shape J|Delta| x J|Delta|. 
This is related to hat{Omega}_l from\n      # the paper but has a different scale here for consistency with\n      # ConvOutputKroneckerFactor.\n      # (Tilde omitted over A for clarity.)\n      return compute_cov(patches_flat)\n\n  def _get_data_device(self, tower):\n    return self._inputs[tower].device\n\n\nclass ConvInputMultiKF(ConvInputKroneckerFactor):\n\n  def __init__(self,\n               inputs,\n               filter_shape,\n               padding,\n               num_uses,\n               strides=None,\n               dilation_rate=None,\n               data_format=None,\n               extract_patches_fn=None,\n               has_bias=False,\n               sub_sample_inputs=None,\n               sub_sample_patches=None,\n               patch_mask=None):\n\n    super(ConvInputMultiKF, self).__init__(inputs,\n                                           filter_shape,\n                                           padding,\n                                           strides=strides,\n                                           dilation_rate=dilation_rate,\n                                           data_format=data_format,\n                                           extract_patches_fn=extract_patches_fn,\n                                           has_bias=has_bias,\n                                           sub_sample_inputs=sub_sample_inputs,\n                                           sub_sample_patches=sub_sample_patches,\n                                           patch_mask=patch_mask)\n    self._num_uses = num_uses\n\n  def _partial_batch_size(self, source=0, tower=0):\n    # Note that some internal comptutations of \"batch_size\" done in the parent\n    # class won't actually be the proper batch size. Instead, they will be\n    # just \"the thing to normalize the statistics by\", essentially. 
This is okay\n    # as we don't mix the two things up.\n    return (super(ConvInputMultiKF, self)._partial_batch_size(source=source,\n                                                              tower=tower)\n            // self._num_uses)\n\n\nclass ConvInputSUAKroneckerFactor(FisherFactor):\n  r\"\"\"Kronecker factor for the input side of a convolutional layer.\n\n  Assumes activations across locations are uncorrelated. Check section 4.2\n  Theorem 4 in https://arxiv.org/pdf/1602.01407.pdf for further details on the\n  assumptions. This is a computationally more efficient approximation,\n  especially for very wide layers.\n  \"\"\"\n\n  def __init__(self, inputs, filter_shape, has_bias=False):\n    \"\"\"Initializes ConvInputSUAKroneckerFactor.\n\n    If `ASSUME_ZERO_MEAN_ACTIVATIONS` is `True` then assumes activations\n    zero mean and the contribution from `M(j) M(j')` term in\n    Theorem 4 from https://arxiv.org/pdf/1602.01407.pdf is ignored.\n\n    Args:\n      inputs: List of Tensors of shape [batch_size, ..spatial_input_size..,\n        in_channels]. Inputs to layer. List index is tower.\n      filter_shape: List of ints. Contains [..spatial_filter_size..,\n        in_channels, out_channels]. Shape of convolution kernel.\n      has_bias: bool. 
If True, appends 1 to mean activations.\n    \"\"\"\n    self._inputs = inputs\n    self._filter_shape = filter_shape\n    self._has_bias = has_bias\n\n    self._kw_kh = np.prod(self._filter_shape[0:-2])\n    self._in_channels = self._filter_shape[-2]\n\n    self._matpower_by_exp_and_damping = OrderedDict()  # { (float, hashable): variable }\n    self._matpower_registrations = set()  # { (float, hashable) }\n    self._damping_funcs_by_id = OrderedDict()  # {hashable: lambda}\n    self._damping_var_by_id = OrderedDict()\n\n    if not ASSUME_ZERO_MEAN_ACTIVATIONS:\n      self._cov_inv_mu_by_damping_id = OrderedDict()\n      self._rank_one_update_scale_by_damping_id = OrderedDict()\n\n    super(ConvInputSUAKroneckerFactor, self).__init__()\n\n  @property\n  def _var_scope(self):\n    return \"ff_convinsuakron_\" + scope_string_from_params(\n        tuple(self._inputs) + tuple((self._filter_shape, self._has_bias)))\n\n  @property\n  def _cov_shape(self):\n    \"\"\"Returns a list with value [in_channels, in_channels].\n\n    NOTE: This does not return the shape of the full cov matrix. But returns the\n    shape of the matrix which computes the covariance of the input channel\n    activations under the assumption mentioned in Theorem 4 in\n    https://arxiv.org/pdf/1602.01407.pdf. 
This does not include bias dimension\n    and also includes only the `Sigma` term from Theorem 4 in\n    the paper.\n    \"\"\"\n    return [self._in_channels, self._in_channels]\n\n  @property\n  def _num_sources(self):\n    return 1\n\n  @property\n  def _num_towers(self):\n    return len(self._inputs)\n\n  @property\n  def _dtype(self):\n    return self._inputs[0].dtype\n\n  @property\n  def mu(self):\n    return self._mu.value\n\n  def _partial_batch_size(self, source=0, tower=0):\n    assert source == 0\n    return utils.get_shape(self._inputs[tower])[0]\n\n  def _register_damping(self, damping_func):\n    damping_id = graph_func_to_id(damping_func)\n    if damping_id not in self._damping_funcs_by_id:\n      self._damping_funcs_by_id[damping_id] = damping_func\n    return damping_id\n\n  def get_inv_vars(self):\n    inv_vars = []\n    inv_vars.extend(self._matpower_by_exp_and_damping.values())\n    return inv_vars\n\n  def instantiate_cov_variables(self):\n    \"\"\"Makes the internal cov variable(s).\"\"\"\n    super(ConvInputSUAKroneckerFactor,\n          self).instantiate_cov_variables()\n\n    # Create variables for computing the mean activations only if\n    # `ASSUME_ZERO_MEAN_ACTIVATIONS` is set to `False`. Otherwise the\n    # contribution from the second term in equation 35 in the paper\n    # https://arxiv.org/pdf/1602.01407.pdf is ignored.\n    if not ASSUME_ZERO_MEAN_ACTIVATIONS:\n      with tf.variable_scope(self._var_scope):\n        self._mu = utils.MovingAverageVariable(\n            name=\"mu\",\n            shape=(self._in_channels, 1),  # number of input channels.\n            dtype=self._dtype,\n            initializer=tf.zeros_initializer(),\n            normalize_value=ZERO_DEBIAS)\n\n  def make_covariance_update_op(self, ema_decay, ema_weight):\n    \"\"\"Constructs and returns the covariance update Op.\n\n    Args:\n      ema_decay: The exponential moving average decay (float or Tensor).\n      ema_weight: float or Tensor. 
The weight to put on the newly computed\n        values. This is typically 1.0 - ema_decay.\n\n    Returns:\n      An Op for updating the covariance Variable referenced by _cov and possibly\n      updating mean activations.\n    \"\"\"\n\n    # The newly computed cov matrix is returned and assigned below to the\n    # moving average. `new_cov` is required to compute mean activations.\n    # Mean activations is given by last row and col of `new_cov.\n    # Remove the last row and col from `new_cov`.\n\n    new_cov = super(ConvInputSUAKroneckerFactor, self)._compute_total_new_cov()\n    new_mu = new_cov[:-1, -1:]\n    new_cov = new_cov[0:-1, 0:-1]\n\n    if not ASSUME_ZERO_MEAN_ACTIVATIONS:\n      new_cov = new_cov - tf.matmul(new_mu, new_mu, transpose_b=True)\n\n      acc_mu_op = self._mu.add_to_average(new_mu, decay=ema_decay,\n                                          weight=ema_weight)\n    else:\n      acc_mu_op = tf.no_op()\n\n      if SUBTRACT_MEAN_CONTRIB_FROM_COV:\n        new_cov = new_cov - tf.matmul(new_mu, new_mu, transpose_b=True)\n\n    acc_cov_op = self._cov.add_to_average(new_cov, decay=ema_decay,\n                                          weight=ema_weight)\n    return tf.group(acc_cov_op, acc_mu_op)\n\n  def _compute_new_cov(self, source, tower):\n    assert source == 0\n    inputs = self._inputs[tower]\n    # Reshape inputs to compute [in_channels, in_channels] shape cov.\n    channel_inputs = tf.reshape(inputs, shape=(-1, self._in_channels))\n\n    # Append the bias dimension as we need this to calculate mean activations.\n    channel_inputs = append_homog(channel_inputs)\n\n    return compute_cov(channel_inputs)\n\n  def register_matpower(self, exp, damping_func):\n    \"\"\"Registers a matrix power to be maintained and served on demand.\n\n    This creates a variable and signals make_inverse_update_ops to make the\n    corresponding update op.  The variable can be read via the method\n    get_matpower.\n\n    Args:\n      exp: float.  
The exponent to use in the matrix power.\n      damping_func: A function that computes a 0-D Tensor or a float which will\n        be the damping value used.  i.e. damping = damping_func().\n    \"\"\"\n    if exp == 1.0:\n      return\n\n    if exp != -1:\n      raise ValueError(\"ConvInputSUAKroneckerFactor supports only\"\n                       \"matrix inversion\")\n\n    damping_id = self._register_damping(damping_func)\n\n    if (exp, damping_id) not in self._matpower_registrations:\n      self._matpower_registrations.add((exp, damping_id))\n\n  def _compute_sm_rank_one_update_quants(self, exp, damping_id, damping_value):\n    \"\"\"Returns tensors to compute Fisher inv using Sherman-Morrison formula.\"\"\"\n\n    cov_inv = self._matpower_by_exp_and_damping[(exp, damping_id)]\n    cov_inv_mu = tf.matmul(cov_inv, self.mu)\n    hatmu_t_cov_inv_hatmu = self._kw_kh * tf.squeeze(\n        tf.matmul(self.mu, cov_inv_mu, transpose_a=True))\n\n    if self._has_bias:\n      tildemu_t_cov_inv_tildemu = hatmu_t_cov_inv_hatmu + (1. / damping_value)\n      return cov_inv_mu, (1. / (1. + tildemu_t_cov_inv_tildemu))\n    else:\n      return cov_inv_mu, (1. / (1. + hatmu_t_cov_inv_hatmu))\n\n  def get_matpower(self, exp, damping_func):\n    # Note that this function returns a variable which gets updated by the\n    # inverse ops.  
It may be stale / inconsistent with the latest value of\n    # self.cov (except when exp == 1).\n    if exp == 1:\n      return self._make_cov_linear_operator(\n          damping=tf.cast(damping_func(), dtype=self._dtype))\n    elif exp == -1:\n      damping_id = graph_func_to_id(damping_func)\n      cov_inv = self._matpower_by_exp_and_damping[(exp, damping_id)]\n      damping_value = self._damping_var_by_id[damping_id]\n\n      # Replicates the in_channels * in_channels cov inverse matrix.\n      # Note that in this function the replications are not done explicitly.\n      # They are done using tf.linalg ops and hence they are computationally\n      # efficient.\n      quant_1 = tf.linalg.LinearOperatorKronecker([\n          tf.linalg.LinearOperatorFullMatrix(\n              cov_inv,\n              is_non_singular=True,\n              is_self_adjoint=True,\n              is_positive_definite=True,\n              is_square=True),\n          tf.linalg.LinearOperatorIdentity(\n              num_rows=self._kw_kh, dtype=self._dtype)\n      ])\n      # If a bias dimension needs to be appended then we need to expand\n      # scaled_cov_inv_mu and assign `1` to the last dimension. Also\n      # we need to append inverse of damping constant (1 * 1 matrix) to\n      # to the replicated cov inverse matrix.\n      if self._has_bias:\n        bias_operator = tf.linalg.LinearOperatorFullMatrix(\n            [[1. / damping_value]],\n            is_non_singular=True,\n            is_self_adjoint=True,\n            is_positive_definite=True,\n            is_square=True)\n        cov_inv_kron_identity_operator = tf.linalg.LinearOperatorBlockDiag(\n            [quant_1, bias_operator])\n\n        if not ASSUME_ZERO_MEAN_ACTIVATIONS:\n          cov_inv_mu = self._cov_inv_mu_by_damping_id[damping_id]\n          scale = self._rank_one_update_scale_by_damping_id[damping_id]\n\n          # Compute cov_inv_mu kron 1's vec. 
We tile the cov_inv_mu on the last\n          # dim and then reshape.\n          mean_update = (\n              tf.expand_dims(\n                  append_homog(\n                      tf.reshape(tf.tile(cov_inv_mu, [1, self._kw_kh]), (-1,)),\n                      homog_value=(1. / damping_value)),\n                  axis=1))\n      else:\n        cov_inv_kron_identity_operator = quant_1\n\n        if not ASSUME_ZERO_MEAN_ACTIVATIONS:\n          cov_inv_mu = self._cov_inv_mu_by_damping_id[damping_id]\n          scale = self._rank_one_update_scale_by_damping_id[damping_id]\n          # Compute cov_inv_mu kron 1's vec. We tile the cov_inv_mu on the last\n          # dim and then reshape.\n          mean_update = tf.reshape(\n              tf.tile(cov_inv_mu, [1, self._kw_kh]), (-1, 1))\n\n      if ASSUME_ZERO_MEAN_ACTIVATIONS:\n        return cov_inv_kron_identity_operator\n      else:\n        # To include the contribution from the mean activations we need to\n        # low rank update op. 
Note the Sherman Morrison formula requires\n        # negative of (mean_update * mean_update^T) / scale term to be added.\n        # In order to achieve this using `LinearOperatorLowRankUpdate` set `v`\n        # to negative of mean update vector multiplied by scale.\n        return tf.linalg.LinearOperatorLowRankUpdate(\n            cov_inv_kron_identity_operator,\n            mean_update,\n            v=-scale * mean_update,\n            is_non_singular=True,\n            is_self_adjoint=True,\n            is_positive_definite=True,\n            is_square=True)\n    else:\n      raise ValueError(\"ConvInputSUAKroneckerFactor only supports\"\n                       \"computing inverse of cov matrix.\")\n\n  def make_inverse_update_ops(self):\n    \"\"\"Creates and return update ops for registered computations.\"\"\"\n    inverse_ops = []\n    for (exp,\n         damping_id), matpower in self._matpower_by_exp_and_damping.items():\n      assert exp == -1\n\n      damping = tf.cast(self._damping_funcs_by_id[damping_id](), self._dtype)\n      damping_assign_op = utils.smart_assign(\n          self._damping_var_by_id[damping_id], damping)\n      inverse_op = utils.smart_assign(matpower,\n                                      utils.posdef_inv(self.cov, damping))\n      inverse_ops.append(damping_assign_op)\n\n      if not ASSUME_ZERO_MEAN_ACTIVATIONS:\n        with tf.control_dependencies([inverse_op]):\n          (cov_inv_mu,\n           rank_one_update_scale) = self._compute_sm_rank_one_update_quants(\n               exp, damping_id, damping)\n\n          inverse_ops.append(\n              utils.smart_assign(self._cov_inv_mu_by_damping_id[damping_id],\n                                 cov_inv_mu))\n          inverse_ops.append(\n              utils.smart_assign(\n                  self._rank_one_update_scale_by_damping_id[damping_id],\n                  rank_one_update_scale))\n      else:\n        inverse_ops.append(inverse_op)\n\n    return inverse_ops\n\n  def 
get_inverse(self, damping_func):\n    # Just for backwards compatibility of some old code and tests\n    return self.get_matpower(-1, damping_func)\n\n  def instantiate_inv_variables(self):\n    \"\"\"Makes the internal \"inverse\" variable(s).\"\"\"\n\n    for (exp, damping_id) in self._matpower_registrations:\n      if exp != -1.:\n        raise ValueError(\"ConvInputSUAKroneckerFactor only supports inverse\"\n                         \"computation\")\n\n      exp_string = scalar_or_tensor_to_string(exp)\n      damping_func = self._damping_funcs_by_id[damping_id]\n      damping_string = graph_func_to_string(damping_func)\n      with tf.variable_scope(self._var_scope):\n        matpower = tf.get_variable(\n            \"matpower_exp{}_damp{}\".format(exp_string, damping_string),\n            initializer=inverse_initializer,\n            shape=self._cov_shape,\n            trainable=False,\n            dtype=self._dtype,\n            use_resource=True)\n\n      assert (exp, damping_id) not in self._matpower_by_exp_and_damping\n      self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower\n\n      self._damping_var_by_id[damping_id] = tf.get_variable(\n          \"damping_var_{}_{}\".format(exp_string, damping_string),\n          initializer=tf.zeros_initializer(),\n          shape=(),\n          trainable=False,\n          dtype=self._dtype,\n          use_resource=True)\n\n      if not ASSUME_ZERO_MEAN_ACTIVATIONS:\n        self._cov_inv_mu_by_damping_id[damping_id] = tf.get_variable(\n            \"cov_inv_mu_{}_{}\".format(exp_string, damping_string),\n            initializer=tf.zeros_initializer(),\n            shape=(self._in_channels, 1),\n            trainable=False,\n            dtype=self._dtype,\n            use_resource=True)\n\n        self._rank_one_update_scale_by_damping_id[damping_id] = tf.get_variable(\n            \"rank_one_update_scale_{}_{}\".format(exp_string, damping_string),\n            initializer=tf.zeros_initializer(),\n          
  shape=(),\n            trainable=False,\n            dtype=self._dtype,\n            use_resource=True)\n\n  def _make_cov_linear_operator(self, damping=None):\n    \"\"\"Returns cov as a linear operator.\n\n    Args:\n      damping: Damping value tensor. If `damping` is not None then returns\n        damped covariance matrix.\n\n    Returns:\n      tf.linalg.LinearOperator instance.\n    \"\"\"\n    if damping is not None:\n      cov = self.cov + damping * tf.eye(self._cov_shape[0], dtype=self._dtype)\n    else:\n      cov = self.cov\n\n    cov_operator = tf.linalg.LinearOperatorKronecker([\n        tf.linalg.LinearOperatorFullMatrix(\n            cov, is_self_adjoint=True, is_square=True),\n        tf.linalg.LinearOperatorIdentity(\n            num_rows=self._kw_kh, dtype=self._dtype)\n    ])\n\n    if self._has_bias:\n      bias_value = damping if damping is not None else 0.\n      bias_operator = tf.linalg.LinearOperatorFullMatrix([[bias_value]],\n                                                         is_self_adjoint=True,\n                                                         is_square=True)\n      cov_operator = tf.linalg.LinearOperatorBlockDiag(\n          [cov_operator, bias_operator])\n\n    if ASSUME_ZERO_MEAN_ACTIVATIONS:\n      return cov_operator\n    else:\n      # self.mu kron 1's vec is computed below by tiling mu.\n      hatmu = tf.tile(self.mu, [1, self._kw_kh])\n\n      if self._has_bias:\n        tildemu = append_homog(tf.reshape(hatmu, (-1,)))\n        mean_update = tf.expand_dims(tildemu, axis=1)\n      else:\n        mean_update = tf.reshape(hatmu, (-1, 1))\n\n      return tf.linalg.LinearOperatorLowRankUpdate(\n          cov_operator, mean_update, is_self_adjoint=True, is_square=True)\n\n  def get_cov_as_linear_operator(self):\n    return self._make_cov_linear_operator()\n\n  def get_cholesky(self, damping_func):\n    raise NotImplementedError(\"ConvInputSUAKroneckerFactor does not support\"\n                              \"cholesky 
factorization\")\n\n  def get_cholesky_inverse(self, damping_func):\n    raise NotImplementedError(\"ConvInputSUAKroneckerFactor does not support\"\n                              \"cholesky inverse computation\")\n\n  def register_cholesky(self):\n    raise NotImplementedError(\"ConvInputSUAKroneckerFactor does not support\"\n                              \"cholesky factorization\")\n\n  def register_cholesky_inverse(self):\n    raise NotImplementedError(\"ConvInputSUAKroneckerFactor does not support\"\n                              \"cholesky inverse computation\")\n\n  def _get_data_device(self, tower):\n    return self._inputs[tower].device\n\n\nclass ConvOutputKroneckerFactor(DenseSquareMatrixFactor):\n  r\"\"\"Kronecker factor for the output side of a convolutional layer.\n\n  Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer\n  given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over\n  all examples and locations.\n\n  Equivalent to Gamma in https://arxiv.org/abs/1602.01407 for details. See\n  Section 3.1 Estimating the factors.\n  \"\"\"\n\n  def __init__(self, outputs_grads, data_format=None):\n    \"\"\"Initializes ConvOutputKroneckerFactor.\n\n    Args:\n      outputs_grads: List of list of Tensors. Each Tensor is of shape\n          [batch_size, ..spatial_input_size.., out_channels].  First list index\n          is source, the second is tower.\n      data_format: None or str. 
Format of outputs_grads.\n\n    Raises:\n      ValueError: If channels are not final dimension.\n    \"\"\"\n    if not utils.is_data_format_channel_last(data_format):\n      raise ValueError(\"Channel must be last.\")\n    self._out_channels = outputs_grads[0][0].shape.as_list()[-1]\n    self._outputs_grads = outputs_grads\n    super(ConvOutputKroneckerFactor, self).__init__()\n\n  @property\n  def _var_scope(self):\n    return \"ff_convoutkron_\" + scope_string_from_params(\n        nest.flatten(self._outputs_grads))\n\n  @property\n  def _cov_shape(self):\n    size = self._out_channels\n    return [size, size]\n\n  @property\n  def _num_sources(self):\n    return len(self._outputs_grads)\n\n  @property\n  def _num_towers(self):\n    return len(self._outputs_grads[0])\n\n  @property\n  def _dtype(self):\n    return self._outputs_grads[0][0].dtype\n\n  def _partial_batch_size(self, source=0, tower=0):\n    return utils.get_shape(self._outputs_grads[source][tower])[0]\n\n  def _compute_new_cov(self, source, tower):\n    outputs_grad = self._outputs_grads[source][tower]\n\n    # reshaped_tensor below is the matrix DS_l defined in the KFC paper\n    # (tilde omitted over S for clarity). 
It has shape M|T| x I, where\n    # M = minibatch size, |T| = number of spatial locations, and\n    # I = number of output maps for convolutional layer l.\n    reshaped_tensor = tf.reshape(outputs_grad, [-1, self._out_channels])\n    # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,\n    # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l\n    # as defined in the paper, with shape I x I.\n    # (Tilde omitted over S for clarity.)\n    return compute_cov(reshaped_tensor)\n\n  def _get_data_device(self, tower):\n    return self._outputs_grads[0][tower].device\n\n\nclass ConvOutputMultiKF(ConvOutputKroneckerFactor):\n\n  def __init__(self, outputs_grads, num_uses, data_format=None):\n    super(ConvOutputMultiKF, self).__init__(outputs_grads,\n                                            data_format=data_format)\n    self._num_uses = num_uses\n\n  def _partial_batch_size(self, source=0, tower=0):\n    # Note that some internal comptutations of \"batch_size\" done in the parent\n    # class won't actually be the proper batch size. Instead, they will be\n    # just \"the thing to normalize the statistics by\", essentially. This is okay\n    # as we don't mix the two things up.\n    return (super(ConvOutputMultiKF, self)._partial_batch_size(source=source,\n                                                               tower=tower)\n            // self._num_uses)\n\n\nclass FullyConnectedMultiKF(FullyConnectedKroneckerFactor):\n  \"\"\"Kronecker factor for a fully connected layer used multiple times.\"\"\"\n\n  def __init__(self,\n               tensors,\n               num_uses=None,\n               has_bias=False):\n    \"\"\"Constructs a new `FullyConnectedMultiKF`.\n\n    Args:\n      tensors: List of list of Tensors of shape, each of shape\n        [num_uses * batch_size, n], and is a reshape version of a Tensor of\n        shape [num_uses, batch_size, n]. 
Each of these tensors is usually a\n        layer's inputs or its output's gradients. The first list index is\n        sources, the second is towers.\n      num_uses: int. The number of time-steps / uses.\n      has_bias: bool. If True, '1' is appended to each row.\n    \"\"\"\n\n    self._num_uses = num_uses\n\n    self._cov_dt1 = None\n    self._acc_cov_dt1 = None\n    self._make_cov_dt1 = False\n    self._option1quants_by_damping = OrderedDict()\n    self._option2quants_by_damping = OrderedDict()\n    self._option1quants_registrations = set()\n    self._option2quants_registrations = set()\n\n    super(FullyConnectedMultiKF, self).__init__(tensors=tensors,\n                                                has_bias=has_bias)\n\n  @property\n  def _num_timesteps(self):\n    return self._num_uses\n\n  def _partial_batch_size(self, source=0, tower=0):\n    shape = utils.get_shape(self._tensors[source][tower])\n    if len(shape) == 2:\n      # the folded case\n      return shape[0] // self._num_timesteps\n    elif len(shape) == 3:\n      return shape[1]  # batch is the second dim\n\n  @property\n  def _var_scope(self):\n    return \"ff_fc_multi_\" + scope_string_from_params(\n        tuple(nest.flatten(self._tensors))\n        + (self._num_timesteps, self._has_bias,))\n\n  def get_inv_vars(self):\n    inv_vars = super(FullyConnectedMultiKF, self).get_inv_vars()\n    inv_vars.extend(self._option1quants_by_damping.values())\n    inv_vars.extend(self._option2quants_by_damping.values())\n    return inv_vars\n\n  def make_covariance_update_op(self, ema_decay, ema_weight):\n\n    op = super(FullyConnectedMultiKF, self).make_covariance_update_op(\n        ema_decay, ema_weight)\n\n    if self._cov_dt1 is not None:\n      new_cov_dt1_contribs = []\n      for source in range(self._num_sources):\n        for tower in range(self._num_towers):\n          with maybe_place_on_device(self._get_data_device(tower)):\n            
new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source,\n                                                                  tower))\n\n      new_cov_dt1 = (tf.add_n(new_cov_dt1_contribs) / float(self._num_towers))\n\n      # See comments in FisherFactor.make_covariance_update_op() for details.\n      new_cov_dt1 = utils.all_average(new_cov_dt1)\n\n      op2 = self._cov_dt1.add_to_average(new_cov_dt1, decay=ema_decay,\n                                         weight=ema_weight)\n      # TODO(b/69112164):\n      # It's important that _cov and _cov_dt1 remain consistent with each\n      # other while the inverse ops are happening. How can we ensure this?\n      # We will need to add explicit synchronization for this to\n      # work with asynchronous training.\n      op = tf.group(op, op2)\n\n    return op\n\n  def _compute_new_cov(self, source, tower):\n    tensor = self._tensors[source][tower]\n    if len(tensor.shape) == 3:\n      tensor = tf.reshape(tensor, [-1, tensor.shape[2]])\n\n    if self._has_bias:\n      tensor = append_homog(tensor)\n    return compute_cov(tensor)\n\n  def _compute_new_cov_dt1(self, source, tower):  # pylint: disable=missing-docstring\n    tensor = self._tensors[source][tower]\n    if len(tensor.shape) == 3:\n      tensor = tf.reshape(tensor, [-1, tensor.shape[2]])\n\n    if self._has_bias:\n      # This appending is technically done twice (the other time is for\n      # _compute_new_cov())\n      tensor = append_homog(tensor)\n\n    total_len = utils.get_shape(tensor)[0]\n    batch_size = total_len // self._num_timesteps\n\n    tensor_present = tensor[:-batch_size, :]\n    tensor_future = tensor[batch_size:, :]\n\n    # We specify a normalizer for this computation to ensure a PSD Fisher\n    # block estimate.  
This is equivalent to padding with zeros, as was done\n    # in Section B.2 of the appendix.\n    return compute_cov(\n        tensor_future, tensor_right=tensor_present, normalizer=total_len)\n\n  @property\n  def _cov_shape(self):\n    shape = self._tensors[0][0].shape\n    if len(shape) == 2:\n      size = shape[1] + self._has_bias\n    elif len(shape) == 3:\n      size = shape[2] + self._has_bias\n    return [size, size]\n\n  def _get_data_device(self, tower):\n    return self._tensors[0][tower].device\n\n  @property\n  def _vec_shape(self):\n    size = self._tensors[0][0].shape[1] + self._has_bias\n    return [size]\n\n  def get_option1quants(self, damping_func):\n    damping_id = graph_func_to_id(damping_func)\n    return self._option1quants_by_damping[damping_id]\n\n  def get_option2quants(self, damping_func):\n    damping_id = graph_func_to_id(damping_func)\n    return self._option2quants_by_damping[damping_id]\n\n  @property\n  def cov_dt1(self):\n    assert self._cov_dt1 is not None\n    return self._cov_dt1.value\n\n  def get_cov_vars(self):\n    cov_vars = super(FullyConnectedMultiKF, self).get_cov_vars()\n    if self._make_cov_dt1:\n      cov_vars += [self.cov_dt1]\n    return cov_vars\n\n  def register_cov_dt1(self):\n    self._make_cov_dt1 = True\n\n  def instantiate_cov_variables(self):\n    super(FullyConnectedMultiKF, self).instantiate_cov_variables()\n    assert self._cov_dt1 is None\n    if self._make_cov_dt1:\n      with tf.variable_scope(self._var_scope):\n        self._cov_dt1 = utils.MovingAverageVariable(\n            name=\"cov_dt1\",\n            shape=self._cov_shape,\n            dtype=self._dtype,\n            initializer=tf.zeros_initializer(),\n            normalize_value=ZERO_DEBIAS)\n\n  def register_option1quants(self, damping_func):\n    damping_id = self._register_damping(damping_func)\n    if damping_id not in self._option1quants_registrations:\n      self._option1quants_registrations.add(damping_id)\n\n  def 
register_option2quants(self, damping_func):\n    damping_id = self._register_damping(damping_func)\n    if damping_id not in self._option2quants_registrations:\n      self._option2quants_registrations.add(damping_id)\n\n  def instantiate_inv_variables(self):\n    super(FullyConnectedMultiKF, self).instantiate_inv_variables()\n\n    for damping_id in self._option1quants_registrations:\n      damping_func = self._damping_funcs_by_id[damping_id]\n      damping_string = graph_func_to_string(damping_func)\n      # It's questionable as to whether we should initialize with stuff like\n      # this at all.  Ideally these values should never be used until they are\n      # updated at least once.\n      with tf.variable_scope(self._var_scope):\n        Lmat = tf.get_variable(  # pylint: disable=invalid-name\n            \"Lmat_damp{}\".format(damping_string),\n            initializer=inverse_initializer,\n            shape=self._cov_shape,\n            trainable=False,\n            dtype=self._dtype,\n            use_resource=True)\n        psi = tf.get_variable(\n            \"psi_damp{}\".format(damping_string),\n            initializer=tf.ones_initializer(),\n            shape=self._vec_shape,\n            trainable=False,\n            dtype=self._dtype,\n            use_resource=True)\n\n      assert damping_id not in self._option1quants_by_damping\n      self._option1quants_by_damping[damping_id] = (Lmat, psi)\n\n    for damping_id in self._option2quants_registrations:\n      damping_func = self._damping_funcs_by_id[damping_id]\n      damping_string = graph_func_to_string(damping_func)\n      # It's questionable as to whether we should initialize with stuff like\n      # this at all.  
Ideally these values should never be used until they are\n      # updated at least once.\n      with tf.variable_scope(self._var_scope):\n        Pmat = tf.get_variable(  # pylint: disable=invalid-name\n            \"Lmat_damp{}\".format(damping_string),\n            initializer=inverse_initializer,\n            shape=self._cov_shape,\n            trainable=False,\n            dtype=self._dtype,\n            use_resource=True)\n        Kmat = tf.get_variable(  # pylint: disable=invalid-name\n            \"Kmat_damp{}\".format(damping_string),\n            initializer=inverse_initializer,\n            shape=self._cov_shape,\n            trainable=False,\n            dtype=self._dtype,\n            use_resource=True)\n        mu = tf.get_variable(\n            \"mu_damp{}\".format(damping_string),\n            initializer=tf.ones_initializer(),\n            shape=self._vec_shape,\n            trainable=False,\n            dtype=self._dtype,\n            use_resource=True)\n\n      assert damping_id not in self._option2quants_by_damping\n      self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu)\n\n  def make_inverse_update_ops(self):\n    \"\"\"Create and return update ops corresponding to registered computations.\"\"\"\n    # TODO(b/69918258): Add correctness tests for this method.\n    # pylint: disable=invalid-name\n\n    ops = []\n\n    if (len(self._option1quants_by_damping) +\n        len(self._option2quants_by_damping)):\n\n      # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from\n      # the pseudo-code in the original paper.  
Because the computations for\n      # the A and G case are essentially the same they can both be performed by\n      # the same class (this one).\n\n      C1 = self.cov_dt1\n\n      # Get the eigendecomposition of C0  (= self.cov)\n      eigen_e, eigen_V = self.get_eigendecomp()\n\n      # TODO(b/69678661): Note, there is an implicit assumption here that C1\n      # and C0 (as represented here by its eigen-decomp) are consistent.  This\n      # could fail to be the case if self._cov and self._cov_dt1 are not updated\n      # consistently, or are somehow read between or during the cov updates.\n      # Can this possibly happen?  Is there a way to prevent it?\n\n      for damping_id, (Lmat_var,\n                       psi_var) in self._option1quants_by_damping.items():\n\n        damping = self._damping_funcs_by_id[damping_id]()\n        damping = tf.cast(damping, self._dtype)\n\n        invsqrtC0 = tf.matmul(\n            eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)\n\n        # Might need to enforce symmetry lost due to numerical issues.\n        invsqrtC0 = (invsqrtC0 + tf.transpose(invsqrtC0)) / 2.0\n\n        # The following line imposes the symmetry assumed by \"Option 1\" on C1.\n        # Strangely the code can work okay with this line commented out,\n        # depending on how psd_eig is defined.  
I'm not sure why.\n        C1 = (C1 + tf.transpose(C1)) / 2.0\n\n        # hPsi = C0^(-1/2) * C1 * C0^(-1/2)  (hPsi means hat{Psi})\n        hPsi = tf.matmul(tf.matmul(invsqrtC0, C1), invsqrtC0)\n\n        # Compute the decomposition U*diag(psi)*U^T = hPsi\n        psi, U = utils.posdef_eig(hPsi)\n\n        # L = C0^(-1/2) * U\n        Lmat = tf.matmul(invsqrtC0, U)\n\n        ops.append(utils.smart_assign(Lmat_var, Lmat))\n        ops.append(utils.smart_assign(psi_var, psi))\n\n      for damping_id, (Pmat_var, Kmat_var,\n                       mu_var) in self._option2quants_by_damping.items():\n\n        damping = self._damping_funcs_by_id[damping_id]()\n        damping = tf.cast(damping, self._dtype)\n\n        # compute C0^(-1/2)\n        invsqrtC0 = tf.matmul(\n            eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)\n\n        # Might need to enforce symmetry lost due to numerical issues.\n        invsqrtC0 = (invsqrtC0 + tf.transpose(invsqrtC0)) / 2.0\n\n        # Compute the product C0^(-1/2) * C1\n        invsqrtC0C1 = tf.matmul(invsqrtC0, C1)\n\n        # hPsi = C0^(-1/2) * C1 * C0^(-1/2)  (hPsi means hat{Psi})\n        hPsi = tf.matmul(invsqrtC0C1, invsqrtC0)\n\n        # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi\n        # Note that we using the notation mu instead of \"m\" for the eigenvalues.\n        # Instead of computing the product hPsi^T * hPsi and then doing an\n        # eigen-decomposition of this we just compute the SVD of hPsi and then\n        # square the singular values to get the eigenvalues. 
For a justification\n        # of this approach, see:\n        # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition\n        sqrtmu, _, E = tf.svd(hPsi)\n        mu = tf.square(sqrtmu)\n\n        # Mathematically, the eigenvalues should not exceed 1.0, but\n        # due to numerical issues, or possible issues with inconsistent\n        # values of C1 and (the eigen-decomposition of) C0 they might. So\n        # we enforce this condition.\n        mu = tf.minimum(mu, 1.0)\n\n        # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1)\n        Pmat = tf.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True)\n\n        # K = C_0^(-1/2) * E\n        Kmat = tf.matmul(invsqrtC0, E)\n\n        ops.append(utils.smart_assign(Pmat_var, Pmat))\n        ops.append(utils.smart_assign(Kmat_var, Kmat))\n        ops.append(utils.smart_assign(mu_var, mu))\n\n    ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops()\n    return [tf.group(*ops)]\n\n    # pylint: enable=invalid-name\n"
  },
  {
    "path": "kfac/python/ops/kfac_utils/__init__.py",
    "content": ""
  },
  {
    "path": "kfac/python/ops/kfac_utils/async_inv_cov_update_kfac_opt.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Implementation of KFAC which runs cov and inv ops asynchronously.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport threading\n# Dependency imports\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import optimizer\n\n_MAX_NUM_COV_INV_UPDATE_THREADS = 10\n\n\nclass AsyncInvCovUpdateKfacOpt(optimizer.KfacOptimizer):\n  \"\"\"Provides functionality to run cov and inv ops asynchronously.\n\n  The update ops are placed on devices in a round robin manner. These ops are\n  run asynchronously in the sense that the training op and cov and inv matrix\n  matrix computations are run independently of each other. 
The cov and inv\n  ops are run in background by threads.\n\n  Example usage:\n   opt = AsyncInvCovUpdateKfacOpt(cov_devices=[\"/gpu:0\"],\n           inv_devices=[\"/gpu:1\"])\n   train_op = opt.minimize(loss)\n   with tf.Session() as sess:\n     opt.run_cov_inv_ops(sess)\n     for _ in range(100):\n       sess.run([train_op])\n     opt.stop_cov_inv_ops(sess)\n  \"\"\"\n\n  def __init__(self,\n               cov_devices,\n               inv_devices,\n               num_cov_inv_update_threads=None,\n               **kwargs):\n    \"\"\"Initializes AsyncInvCovUpdateKfacOpt.\n\n    See the docstring for `KfacOptimizer` class (in optimizer.py) for\n    complete list of arguments (there are many!).\n\n    Args:\n      cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance\n        computations will be placed on these devices in a round-robin fashion.\n        Can be None, which means that no devices are specified.\n      inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion\n        computations will be placed on these devices in a round-robin fashion.\n        Can be None, which means that no devices are specified.\n      num_cov_inv_update_threads: `int`, Number of parallel computations of\n        inverse and covariance ops. 
If a value is not passed then the number of\n        threads will be set to half of length of number of ops to run\n        asynchronously (Capped at `_MAX_NUM_COV_INV_UPDATE_THREADS`).\n        (Default: None)\n      **kwargs: Arguments to `KfacOptimizer` class.\n    \"\"\"\n    self.next_op = None\n    self._coord = None\n    self._num_cov_inv_update_threads = num_cov_inv_update_threads\n    self._threads = None\n    super(AsyncInvCovUpdateKfacOpt, self).__init__(\n        placement_strategy=\"round_robin\", **kwargs)\n\n  def _make_ops(self, update_thunks):\n    return [thunk() for thunk in update_thunks]\n\n  def apply_gradients(self, grads_and_vars, global_step=None, name=None):\n    cov_update_thunks, inv_update_thunks = self.make_vars_and_create_op_thunks()\n    apply_grads = super(AsyncInvCovUpdateKfacOpt,\n                        self).apply_gradients(\n                            grads_and_vars=grads_and_vars,\n                            global_step=global_step,\n                            name=name)\n    self._set_up_op_name_queue(\n        self._make_ops(cov_update_thunks + inv_update_thunks))\n    return apply_grads\n\n  def run_cov_inv_ops(self, sess):\n    \"\"\"Starts threads to run covariance and inverse ops.\"\"\"\n    self._coord = tf.train.Coordinator()\n    self._threads = [\n        threading.Thread(target=self._run_ops, args=(\n            (sess,)\n        )) for _ in range(self._num_cov_inv_update_threads)\n    ]\n    for t in self._threads:\n      t.start()\n\n  def _run_ops(self, sess):\n    \"\"\"Runs the covariance and inverse ops.\n\n    Each thread gets the next op name to run from the shared dataset that is\n    created in `_set_up_op_name_queue` method. 
The opname is mapped to the\n    op which is run in thread context.\n\n    Args:\n      sess: `tf.Session` instance.\n    \"\"\"\n    while not self._coord.should_stop():\n      next_op_name = sess.run(self._next_op_name).decode(\"ascii\")\n      next_op = self._ops_by_name[next_op_name]\n      sess.run(next_op)\n\n  def stop_cov_inv_ops(self, sess):\n    \"\"\"Signals coordinator to stop and waits for threads to terminate.\"\"\"\n    self._coord.request_stop()\n    self._coord.join(self._threads)\n\n  def _set_up_op_name_queue(self, ops_to_run):\n    \"\"\"Sets up a queue of op names.\n\n    Convert the names of ops to run to tensors and creates a dataset of names.\n    The op name tensors in the Dataset are repeated indefinitely. Running\n    `self._next_op_name` returns the name of the next op to execute.\n\n    Args:\n      ops_to_run: `List` of ops to run asynchronously.\n    \"\"\"\n    self._num_cov_inv_update_threads = self._num_cov_inv_update_threads or max(\n        int(len(ops_to_run) / 2), _MAX_NUM_COV_INV_UPDATE_THREADS)\n    self._ops_by_name = {op.name: op for op in ops_to_run}\n    op_names = tf.convert_to_tensor(list(sorted(op.name for op in ops_to_run)))\n    op_names_dataset = tf.data.Dataset.from_tensor_slices(op_names).repeat()\n    self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next()\n"
  },
  {
    "path": "kfac/python/ops/kfac_utils/data_reader.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Reads variable size batches of data from a data set and stores read data.\n\n`VariableBatchReader` reads variable size data from a dataset.\n`CachedDataReader` on top of `VariableBatchReader` adds functionality to store\nthe read batch for use in the next session.run() call.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport tensorflow.compat.v1 as tf\n\n\ndef _slice_data(stored_data, size):\n  return [data[:size] for data in stored_data]\n\n\nclass VariableBatchReader(object):\n  \"\"\"Read data of varying batch sizes from a data set.\"\"\"\n\n  def __init__(self, dataset, max_batch_size):\n    \"\"\"Initializes class.\n\n    Args:\n      dataset: List of Tensors representing the dataset, shuffled, repeated,\n        and batched into mini-batches of size at least `max_batch_size`.  In\n        other words it should be reshuffled at each session.run call.  This can\n        be done with the tf.data package using the construction demonstrated in\n        load_mnist() function in examples/autoencoder_auto_damping.py.\n      max_batch_size: `int`. 
Maximum batch size of the data that can be\n        retrieved from the data set.\n    \"\"\"\n    self._dataset = dataset\n    self._max_batch_size = max_batch_size\n\n  def __call__(self, batch_size):\n    \"\"\"Reads `batch_size` data.\n\n    Args:\n      batch_size: Tensor of type `int32`, batch size of the data to be\n        retrieved from the dataset. `batch_size` should be less than or\n        equal to `max_batch_size`.\n\n    Returns:\n       Read data, An iterable of tensors with batch size equal to `batch_size`.\n    \"\"\"\n    check_size = tf.assert_less_equal(\n        batch_size,\n        tf.convert_to_tensor(self._max_batch_size, dtype=tf.int32),\n        message='Data set read failure, Batch size greater than max allowed.'\n    )\n    with tf.control_dependencies([check_size]):\n      return _slice_data(self._dataset, batch_size)\n\n\nclass CachedDataReader(VariableBatchReader):\n  \"\"\"Provides functionality to store variable batch size data.\"\"\"\n\n  def __init__(self, dataset, max_batch_size):\n    \"\"\"Initializes class and creates variables for storing previous batch.\n\n    Args:\n      dataset: List of Tensors representing the dataset, shuffled, repeated,\n        and batched into mini-batches of size at least `max_batch_size`.  In\n        other words it should be reshuffled at each session.run call.  This can\n        be done with the tf.data package using the construction demonstrated in\n        load_mnist() function in examples/autoencoder_auto_damping.py.\n      max_batch_size: `int`. 
Maximum batch size of the data that can be\n        retrieved from the data set.\n    \"\"\"\n    super(CachedDataReader, self).__init__(dataset, max_batch_size)\n    with tf.variable_scope('cached_data_reader'):\n      self._cached_batch_storage = [\n          tf.get_variable(\n              name='{}{}'.format('cached_batch_storage_', i),\n              shape=[max_batch_size]+ var.shape.as_list()[1:],\n              dtype=var.dtype,\n              trainable=False,\n              use_resource=True) for i, var in enumerate(self._dataset)\n      ]\n      self._cached_batch_size = tf.get_variable(\n          name='cached_batch_size', shape=(), dtype=tf.int32, trainable=False,\n          use_resource=True)\n\n      self._cached_batch = _slice_data(self._cached_batch_storage,\n                                       self._cached_batch_size)\n\n  def __call__(self, batch_size):\n    \"\"\"Reads `batch_size` data and stores the read batch.\n\n    Args:\n      batch_size: Tensor of type `int32`, batch size of the data to be\n        retrieved from the dataset. `batch_size` should be less than or\n        equal to `max_batch_size`.\n\n    Returns:\n       Read data, An iterable of tensors with batch size equal to `batch_size`.\n    \"\"\"\n    sliced_data = super(CachedDataReader, self).__call__(batch_size)\n\n    # We need to make sure we read the cached batch before we update it!\n    with tf.control_dependencies(self._cached_batch):\n      batch_size_assign_op = self._cached_batch_size.assign(batch_size)\n      data_assign_ops = [\n          prev[:batch_size].assign(cur)  # yes, this actually works\n          for prev, cur in zip(self._cached_batch_storage, sliced_data)\n      ]\n      with tf.control_dependencies(data_assign_ops + [batch_size_assign_op]):\n        return [tf.identity(sdata) for sdata in sliced_data]\n\n  @property\n  def cached_batch(self):\n    return self._cached_batch\n"
  },
  {
    "path": "kfac/python/ops/kfac_utils/data_reader_alt.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Reads variable size batches of data from a data set and stores read data.\n\n`VariableBatchReader` reads variable size data from a dataset.\n`CachedDataReader` on top of `VariableBatchReader` adds functionality to store\nthe read batch for use in the next session.run() call.\n\nThis file is similar to data_reader.py but uses an alternative implementation\nthat requires the whole dataset to be passed in. This will often be faster than\nusing the original implementation with a very large max_batch_size.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport tensorflow.compat.v1 as tf\n\n\ndef _extract_data(tensor_list, indices):\n  return [tf.gather(tensor, indices, axis=0) for tensor in tensor_list]\n\n\nclass VariableBatchReader(object):\n  \"\"\"Read data of varying batch sizes from a data set.\"\"\"\n\n  def __init__(self, dataset, num_examples):\n    \"\"\"Initializes class.\n\n    Args:\n      dataset: List of Tensors. These must remain constant across session.run\n        calls, unlike the version of VariableBatchReader in data_reader.py.\n      num_examples: The number of examples in the data set (i.e. 
dimension 0\n        of the elements of `dataset`).\n    \"\"\"\n    self._dataset = dataset\n    self._num_examples = num_examples\n    self._indices = None\n\n  def __call__(self, batch_size):\n    \"\"\"Reads `batch_size` data.\n\n    Args:\n      batch_size: Tensor of type `int32`. Batch size of the data to be\n        retrieved from the dataset. `batch_size` should be less than or\n        equal to the number of examples in the dataset.\n\n    Returns:\n       Read data, a list of Tensors with batch size equal to `batch_size`.\n    \"\"\"\n    check_size = tf.assert_less_equal(\n        batch_size,\n        tf.convert_to_tensor(self._num_examples, dtype=tf.int32),\n        message='Data set read failure, batch_size > num_examples.'\n    )\n    with tf.control_dependencies([check_size]):\n      self._indices = tf.random.shuffle(\n          tf.range(self._num_examples, dtype=tf.int32))\n      return _extract_data(self._dataset, self._indices[:batch_size])\n\n\nclass CachedDataReader(VariableBatchReader):\n  \"\"\"Provides functionality to store variable batch size data.\"\"\"\n\n  def __init__(self, dataset, num_examples):\n    \"\"\"Initializes class and creates variables for storing previous batch.\n\n    Args:\n      dataset: List of Tensors. These must remain constant across session.run\n        calls, unlike the version of VariableBatchReader in data_reader.py.\n      num_examples: The number of examples in the data set (i.e. 
dimension 0\n        of the elements of `dataset`).\n    \"\"\"\n    super(CachedDataReader, self).__init__(dataset, num_examples)\n\n    self._cached_batch_indices = tf.get_variable(\n        name='cached_batch_indices',\n        shape=[self._num_examples],\n        dtype=tf.int32,\n        trainable=False,\n        use_resource=True)\n\n    self._cached_batch_size = tf.get_variable(\n        name='cached_batch_size', shape=(), dtype=tf.int32, trainable=False,\n        use_resource=True)\n\n    self._cached_batch = _extract_data(\n        self._dataset,\n        self._cached_batch_indices[:self._cached_batch_size])\n\n  def __call__(self, batch_size):\n    \"\"\"Reads `batch_size` data and stores the read batch.\n\n    Args:\n      batch_size: Tensor of type `int32`, batch size of the data to be\n        retrieved from the dataset. `batch_size` should be less than or\n        equal to `max_batch_size`.\n\n    Returns:\n       Read data, An iterable of tensors with batch size equal to `batch_size`.\n    \"\"\"\n    tensor_list = super(CachedDataReader, self).__call__(batch_size)\n\n    with tf.control_dependencies(self._cached_batch):\n      indices_assign_op = self._cached_batch_indices.assign(self._indices)\n      batch_size_assign_op = tf.assign(self._cached_batch_size, batch_size)\n\n      with tf.control_dependencies([indices_assign_op, batch_size_assign_op]):\n        return [tf.identity(tensor) for tensor in tensor_list]\n\n  @property\n  def cached_batch(self):\n    return self._cached_batch\n\n"
  },
  {
    "path": "kfac/python/ops/kfac_utils/periodic_inv_cov_update_kfac_opt.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Implementation of KFAC which runs covariance and inverse ops periodically.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\n\nfrom absl import logging\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import optimizer\nfrom kfac.python.ops import utils\n\n\nclass PeriodicInvCovUpdateKfacOpt(optimizer.KfacOptimizer):\n  \"\"\"Provides functionality to run covariance and inverse ops periodically.\n\n  Creates KFAC optimizer with a `placement strategy`.\n  Also runs the covariance and inverse ops periodically. The base class\n  does not provide a mechanism to automatically construct and run the covariance\n  and inverse ops, they must be created and run manually using\n  make_vars_and_create_op_thunks or create_ops_and_vars_thunks. This class\n  provides functionality to create these ops and runs them periodically whenever\n  optimizer.minimize op is run.\n\n  The inverse ops are run `invert_every` iterations and covariance statistics\n  are updated `cov_update_every` iterations. Ideally set\n  the `invert_every` to a multiple of `cov_update_every` so that the\n  inverses are computed after the covariance is updated. 
The higher the multiple, the\n  more the delay in using the computed covariance estimates in the KFAC update\n  step. Also computing the statistics and inverses periodically saves on\n  computation cost and a \"reasonable\" value often does not show any performance\n  degradation compared to computing these quantities every iteration.\n  \"\"\"\n\n  def __init__(self,\n               invert_every=10,\n               cov_update_every=1,\n               num_burnin_steps=0,\n               **kwargs):\n    \"\"\"Initializes a PeriodicInvCovUpdateKfacOptimizer object.\n\n    See the docstring for `KfacOptimizer` class (in optimizer.py) for\n    complete list of arguments (there are many!).\n\n    Please keep in mind that while the K-FAC code loosely conforms to\n    TensorFlow's Optimizer API, it can't be used naively as a \"drop in\n    replacement\" for basic classes like MomentumOptimizer.  Using it\n    properly with SyncReplicasOptimizer, for example, requires special care.\n\n    See the various examples in the \"examples\" directory for a guide about\n    how to use K-FAC in various contexts and various systems, like\n    TF-Estimator. See in particular the convnet example.  google/examples\n    also contains an example using TPUEstimator.\n\n    Note that not all use cases will work with\n    PeriodicInvCovUpdateKfacOptimizer. Sometimes you will have to use the base\n    KfacOptimizer which provides more fine-grained control over ops.  Other\n    times you might want to use one of the other subclassed optimizers like\n    AsyncInvCovUpdateKfacOpt.\n\n    Args:\n      invert_every: int. The inversion ops are run once every `invert_every`\n        executions of the training op. (Default: 10)\n      cov_update_every: int. The 'covariance update ops' are run once every\n        `covariance_update_every` executions of the training op. (Default: 1)\n      num_burnin_steps: int. For the first `num_burnin_steps` steps the\n        optimizer will only perform cov updates. 
Note: this doesn't work with\n        CrossShardOptimizer, since the custom minimize method implementation\n        will be ignored, or with MirroredStrategy, due to behavior of\n        conditional parameter updates with multiple replicas. (Default: 0)\n      **kwargs: Arguments to `KfacOptimizer` class.\n\n    Raises:\n      ValueError: if num_burnin_steps is non-zero and MirroredStrategy is being\n      used.\n    \"\"\"\n\n    if \"cov_ema_decay\" in kwargs:\n      kwargs[\"cov_ema_decay\"] = kwargs[\"cov_ema_decay\"]**cov_update_every\n\n    super(PeriodicInvCovUpdateKfacOpt, self).__init__(**kwargs)\n\n    self._invert_every = invert_every\n    self._cov_update_every = cov_update_every\n    self._num_burnin_steps = num_burnin_steps\n\n    self._made_vars_already = False\n\n    if self._adapt_damping:\n      if self._damping_adaptation_interval % self._invert_every != 0:\n        logging.warning(\"WARNING: damping_adaptation_interval isn't divisible \"\n                        \"by invert_every.\")\n\n    if (tf.distribute.has_strategy() and tf.distribute.get_replica_context()):\n      strategy = tf.distribute.get_strategy()\n      if (isinstance(strategy, tf.distribute.MirroredStrategy) and\n          self._num_burnin_steps > 0):\n        raise ValueError(\"num_burnin_steps must be 0 with MirroredStrategy.\")\n\n    with tf.variable_scope(self.get_name()):\n      self._burnin_counter = tf.get_variable(\n          \"burnin_counter\", dtype=tf.int64, shape=(), trainable=False,\n          initializer=tf.zeros_initializer, use_resource=True,\n          aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)\n\n  def minimize(self,\n               loss,\n               global_step=None,\n               var_list=None,\n               gate_gradients=tf.train.Optimizer.GATE_OP,\n               aggregation_method=None,\n               colocate_gradients_with_ops=True,\n               name=None,\n               grad_loss=None,\n               **kwargs):\n    # This 
method has the same general arguments as the minimize methods in\n    # standard optimizers do.\n\n    if not self._made_vars_already:\n      cov_update_thunks, _ = self.make_vars_and_create_op_thunks()\n    else:\n      (_, cov_update_thunks, _, _) = self.create_ops_and_vars_thunks()\n\n    self._made_vars_already = True\n\n    def update_cov_and_burnin_counter():\n      cov_update = tf.group(*(thunk(should_decay=False)\n                              for thunk in cov_update_thunks))\n\n      burnin_counter_update = self._burnin_counter.assign(\n          self._burnin_counter + 1)\n\n      return tf.group(cov_update, burnin_counter_update)\n\n    def super_minimize():\n      return super(PeriodicInvCovUpdateKfacOpt, self).minimize(\n          loss,\n          global_step=global_step,\n          var_list=var_list,\n          gate_gradients=gate_gradients,\n          aggregation_method=aggregation_method,\n          colocate_gradients_with_ops=colocate_gradients_with_ops,\n          name=name,\n          grad_loss=grad_loss,\n          **kwargs)\n\n    if self._num_burnin_steps == 0:\n      return super_minimize()\n    else:\n      return tf.cond(self._burnin_counter < self._num_burnin_steps,\n                     update_cov_and_burnin_counter, super_minimize)\n\n  def apply_gradients(self, grads_and_vars, global_step=None, name=None):\n    with tf.control_dependencies([self.kfac_update_ops()]):\n      return super(PeriodicInvCovUpdateKfacOpt, self).apply_gradients(\n          grads_and_vars=grads_and_vars,\n          global_step=global_step,\n          name=name)\n\n  def kfac_update_ops(self):\n    \"\"\"Sets up the KFAC factor update ops.\n\n    Returns:\n      An op that when run will run the update ops at their update frequencies.\n    \"\"\"\n    # This if-statement is a trick/hack to maintain compatibility with\n    # CrossShardOptimizer or other optimizers that might not call our\n    # custom minimize() method (that would otherwise always make the 
variables).\n    if not self._made_vars_already:\n      (cov_update_thunks,\n       inv_update_thunks) = self.make_vars_and_create_op_thunks()\n      logging.warning(\"It looks like apply_gradients() was called before \"\n                      \"minimize() was called. This is not recommended, and you \"\n                      \"should avoid using optimizer wrappers like \"\n                      \"CrossShardOptimizer with K-FAC that try to bypass the \"\n                      \"minimize() method. The burn-in feature won't work when \"\n                      \"the class is used this way, for example. And K-FAC does \"\n                      \"its own cross-replica synchronization.\")\n    else:\n      (_, cov_update_thunks,\n       _, inv_update_thunks) = self.create_ops_and_vars_thunks()\n\n    should_do_cov_updates = tf.equal(tf.mod(self.counter,\n                                            self._cov_update_every), 0)\n    maybe_cov_updates = utils.smart_cond(\n        should_do_cov_updates,\n        lambda: tf.group(*(thunk() for thunk in cov_update_thunks)),\n        tf.no_op)\n\n    maybe_pre_update_adapt_damping = self.maybe_pre_update_adapt_damping()\n    with tf.control_dependencies([maybe_cov_updates,\n                                  maybe_pre_update_adapt_damping]):\n      should_do_inv_updates = tf.equal(tf.mod(self.counter,\n                                              self._invert_every), 0)\n      maybe_inv_updates = utils.smart_cond(\n          should_do_inv_updates,\n          lambda: tf.group(*(thunk() for thunk in inv_update_thunks)),\n          tf.no_op)\n      return maybe_inv_updates\n"
  },
  {
    "path": "kfac/python/ops/layer_collection.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Registry for layers and their parameters/variables.\n\nThis represents the collection of all layers in the approximate Fisher\ninformation matrix to which a particular FisherBlock may belong. That is, we\nmight have several layer collections for one TF graph (if we have multiple K-FAC\noptimizers being used, for example.)\n\nThe model and loss function are registered using the register_XXX() methods.\nA subset of the layer types can be handled with the auto_register_layers()\nmethod.\n\nNote that the data formats in the docstrings for the register_XXX() methods\nmust be strictly adhered to. So for example, if a method asks for a Tensor of\nshape [batch_size, ...], then the first dimension must be the batch size and\nnothing else.  And the tensors must contain actual data, not a mixture of real\nand fake data / zeros generated by mini-batch padding, for example.  (Padding\nis only fine if it's treated as regular data by both your model and loss\nfunction. e.g. adding \"blank tokens\" at the end of a sequence which the model\nis still expected to predict.) 
If a method asks for the  parameters of a layer\nthen they must be the actual variable object(s) for said parameters, not a\ntensor formed by reshaping, re-casting, or transposing its value.\n\nIf the internal data format used by your model isn't natively supported by\nthis system, you shouldn't try to crow-bar the arguments of the registration\nmethods until they seem to fit. Although the K-FAC code tries to protect\nagainst some common mistakes, it may often seem to run fine with incorrect\nregistrations, generating no exceptions or errors. But this will almost\ncertainly lead to (potentially severe) underperformance of the method.\n\nIf you have model code that doesn't represent tensors in the format expected\nby K-FAC, one thing you can try is introducing transformations that perform the\nconversion back and forth. But make sure the format that you convert to is\nactually valid according to the strict specifications of the registration\nfunction docstrings (e.g. that batch_size really is the mini-batch size, etc).\n\nSo if \"x\" is some data needed in the registration function that isn't of the\ncorrect format, you can try something like the following:\n\nx_transformed = transform(x)\nlc.register_XXX(x_transformed)\nx = untransform(x_transformed)\n...use x in rest of model...\n\nNote that without \"x = untransform(x_transformed)\" this often won't work since\nx_transformed won't be part of the model's forward graph, which is something\nK-FAC needs (especially for the \"output\" arguments of layers).\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom collections import defaultdict\nfrom collections import OrderedDict\nfrom contextlib import contextmanager\nfrom functools import partial\nimport math\n\n# Dependency imports\nimport six\nimport tensorflow.compat.v1 as tf\n\nfrom tensorflow.python.util import nest\nfrom kfac.python.ops import fisher_blocks as fb\nfrom kfac.python.ops 
import loss_functions as lf\nfrom kfac.python.ops import utils\nfrom kfac.python.ops.tensormatch import graph_search\n\n# Names for various approximations that can be requested for Fisher blocks.\nAPPROX_KRONECKER_NAME = \"kron\"\nAPPROX_KRONECKER_IN_DIAG_NAME = \"kron_in_diag\"\nAPPROX_KRONECKER_OUT_DIAG_NAME = \"kron_out_diag\"\nAPPROX_KRONECKER_BOTH_DIAG_NAME = \"kron_both_diag\"\nAPPROX_DIAGONAL_NAME = \"diagonal\"\nAPPROX_FULL_NAME = \"full\"\n\nAPPROX_KRONECKER_INDEP_NAME = \"kron_indep\"\nAPPROX_KRONECKER_INDEP_IN_DIAG_NAME = \"kron_indep_in_diag\"\nAPPROX_KRONECKER_INDEP_OUT_DIAG_NAME = \"kron_indep_out_diag\"\nAPPROX_KRONECKER_INDEP_BOTH_DIAG_NAME = \"kron_indep_both_diag\"\nAPPROX_KRONECKER_SERIES_1_NAME = \"kron_series_1\"\nAPPROX_KRONECKER_SERIES_2_NAME = \"kron_series_2\"\nAPPROX_KRONECKER_SUA_NAME = \"kron_sua\"\n\n\n# Possible value for 'reuse' keyword argument. Sets 'reuse' to\n# tf.get_variable_scope().reuse.\nVARIABLE_SCOPE = \"VARIABLE_SCOPE\"\n\n_DEFAULT_LAYER_COLLECTION = None\n\n\ndef get_default_layer_collection():\n  \"\"\"Get default LayerCollection.\"\"\"\n  if _DEFAULT_LAYER_COLLECTION is None:\n    raise ValueError(\n        \"Attempted to retrieve default LayerCollection when none is set. 
Use \"\n        \"LayerCollection.as_default().\")\n\n  return _DEFAULT_LAYER_COLLECTION\n\n\ndef set_default_layer_collection(layer_collection):\n  global _DEFAULT_LAYER_COLLECTION\n\n  if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None:\n    raise ValueError(\"Default LayerCollection is already set.\")\n\n  _DEFAULT_LAYER_COLLECTION = layer_collection\n\n\nclass LayerParametersDict(OrderedDict):\n  \"\"\"An OrderedDict where keys are Tensors or tuples of Tensors.\n\n  Ensures that no Tensor is associated with two different keys.\n  \"\"\"\n\n  def __init__(self, *args, **kwargs):\n    self._tensors = set()\n    super(LayerParametersDict, self).__init__(*args, **kwargs)\n\n  def __setitem__(self, key, value):\n    key = self._canonicalize_key(key)\n    tensors = key if isinstance(key, (tuple, list)) else (key,)\n    key_collisions = self._tensors.intersection(tensors)\n    if key_collisions:\n      raise ValueError(\"Key(s) already present: {}\".format(key_collisions))\n    self._tensors.update(tensors)\n    super(LayerParametersDict, self).__setitem__(key, value)\n\n  def __delitem__(self, key):\n    key = self._canonicalize_key(key)\n    self._tensors.remove(key)\n    super(LayerParametersDict, self).__delitem__(key)\n\n  def __getitem__(self, key):\n    key = self._canonicalize_key(key)\n    return super(LayerParametersDict, self).__getitem__(key)\n\n  def __contains__(self, key):\n    key = self._canonicalize_key(key)\n    return super(LayerParametersDict, self).__contains__(key)\n\n  def _canonicalize_key(self, key):\n    if isinstance(key, (list, tuple)):\n      return tuple(key)\n    return key\n\n\n# TODO(b/68034464): add capability for LayerCollection to be \"finalized\"\n# and do this when it gets used by FisherEstimator / KfacOptimizer.\n\n\nclass LayerCollection(object):\n  \"\"\"Registry of information about layers and losses.\n\n  Note that you need to create a new one of these for each FisherEstimator or\n  KfacOptimizer, as 
they can't be used more than once.\n\n  The methods that you should interact with directly are:\n   - register_XXX()\n   - auto_register_layers()\n\n  Additional control over the automatic registration process can be exerted by\n  using the methods/properties:\n   - set_default_XXX() and default_XXX\n   - define_linked_parameters() and linked_parameters\n\n\n  Attributes:\n    fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer\n        parameters (Tensors or tuples of Tensors) to FisherBlock instances.\n    fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.\n    losses: a list of LossFunction objects. The loss to be optimized is their\n        sum.\n    loss_colocation_ops: ops to colocate loss function evaluations with.  These\n        will typically be the inputs to the losses.\n  \"\"\"\n\n  def __init__(self,\n               graph=None,\n               name=\"LayerCollection\"):\n    self.fisher_blocks = LayerParametersDict()\n    self.fisher_factors = OrderedDict()\n    self._linked_parameters = dict(\n    )  # dict mapping sets of variables to optionally specified approximations.\n    self._graph = graph or tf.get_default_graph()\n    self._loss_dict = OrderedDict()  # {str: LossFunction}\n    self._subgraph = None\n    self._default_generic_approximation = APPROX_DIAGONAL_NAME\n    self._default_fully_connected_approximation = APPROX_KRONECKER_NAME\n    self._default_conv2d_approximation = APPROX_KRONECKER_NAME\n    self._default_fully_connected_multi_approximation = (\n        APPROX_KRONECKER_INDEP_NAME)\n    self._default_conv2d_multi_approximation = (\n        APPROX_KRONECKER_INDEP_NAME)\n    self._default_scale_and_shift_approximation = APPROX_FULL_NAME\n    self.loss_colocation_ops = {}\n    self.loss_coeffs = {}\n    self._vars_to_uses = defaultdict(lambda: 0)\n\n    self._finalized = False\n\n    with tf.variable_scope(None, default_name=name) as scope:\n      self._var_scope = scope.name\n\n    
self._generic_approx_to_block_types = {\n        APPROX_FULL_NAME: fb.NaiveFullFB,\n        APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,\n    }\n\n    self._fully_connected_approx_to_block_types = {\n        APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,\n        APPROX_KRONECKER_IN_DIAG_NAME:\n            partial(fb.FullyConnectedKFACBasicFB,\n                    diagonal_approx_for_input=True),\n        APPROX_KRONECKER_OUT_DIAG_NAME:\n            partial(fb.FullyConnectedKFACBasicFB,\n                    diagonal_approx_for_output=True),\n        APPROX_KRONECKER_BOTH_DIAG_NAME:\n            partial(fb.FullyConnectedKFACBasicFB,\n                    diagonal_approx_for_input=True,\n                    diagonal_approx_for_output=True),\n        APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,\n    }\n\n    self._conv2d_approx_to_block_types = {\n        APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB,\n        APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,\n        APPROX_KRONECKER_SUA_NAME: fb.ConvKFCBasicFB,\n    }\n\n    self._fully_connected_multi_approx_to_block_types = {\n        APPROX_KRONECKER_INDEP_NAME:\n            fb.FullyConnectedMultiIndepFB,\n        APPROX_KRONECKER_INDEP_IN_DIAG_NAME:\n            partial(fb.FullyConnectedMultiIndepFB,\n                    diagonal_approx_for_input=True),\n        APPROX_KRONECKER_INDEP_OUT_DIAG_NAME:\n            partial(fb.FullyConnectedMultiIndepFB,\n                    diagonal_approx_for_output=True),\n        APPROX_KRONECKER_INDEP_BOTH_DIAG_NAME:\n            partial(fb.FullyConnectedMultiIndepFB,\n                    diagonal_approx_for_input=True,\n                    diagonal_approx_for_output=True),\n        APPROX_KRONECKER_SERIES_1_NAME:\n            partial(fb.FullyConnectedSeriesFB, option=1),\n        APPROX_KRONECKER_SERIES_2_NAME:\n            partial(fb.FullyConnectedSeriesFB, option=2)\n    }\n\n    self._conv2d_multi_approx_to_block_types = {\n        APPROX_KRONECKER_INDEP_NAME: 
fb.ConvKFCBasicMultiIndepFB\n    }\n\n    self._scale_and_shift_approx_to_block_types = {\n        APPROX_FULL_NAME: fb.ScaleAndShiftFullFB,\n        APPROX_DIAGONAL_NAME: fb.ScaleAndShiftDiagonalFB\n    }\n\n  @property\n  def losses(self):\n    \"\"\"Tuple of LossFunction objects registered with this LayerCollection.\"\"\"\n    return nest.flatten(self.towers_by_loss)\n\n  @property\n  def towers_by_loss(self):\n    \"\"\"Tuple across losses of LossFunction objects registered to each tower.\"\"\"\n    return tuple(tuple(lst) for lst in self._loss_dict.values())\n\n  @property\n  def registered_variables(self):\n    \"\"\"A tuple of all of the variables currently registered.\"\"\"\n    tuple_of_tuples = (utils.ensure_sequence(key) for key, block\n                       in six.iteritems(self.fisher_blocks))\n    flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_)\n    return flat_tuple\n\n  @property\n  def linked_parameters(self):\n    \"\"\"Groups of parameters with an optionally specified approximation.\n\n    Linked parameters can be added using `define_linked_parameters`.\n    If an approximation is specified, then this approximation will be used\n    when registering a layer with exactly these parameters, unless an\n    approximation is specified when calling the registration function.\n\n    Returns:\n      A `dict` mapping tuples of parameters to an optional string.\n    \"\"\"\n    return self._linked_parameters\n\n  @property\n  def default_generic_approximation(self):\n    return self._default_generic_approximation\n\n  def set_default_generic_approximation(self, value):\n    if value not in self._generic_approx_to_block_types:\n      raise ValueError(\n          \"{} is not a valid approximation for generic variables.\".format(\n              value))\n    self._default_generic_approximation = value\n\n  @property\n  def default_fully_connected_approximation(self):\n    return self._default_fully_connected_approximation\n\n  def 
set_default_fully_connected_approximation(self, value):\n    if value not in self._fully_connected_approx_to_block_types:\n      raise ValueError(\n          \"{} is not a valid approximation for fully connected layers.\".format(\n              value))\n    self._default_fully_connected_approximation = value\n\n  @property\n  def default_conv2d_approximation(self):\n    return self._default_conv2d_approximation\n\n  def set_default_conv2d_approximation(self, value):\n    if value not in self._conv2d_approx_to_block_types:\n      raise ValueError(\n          \"{} is not a valid approximation for 2d convolutional layers.\".format(\n              value))\n    self._default_conv2d_approximation = value\n\n  @property\n  def default_fully_connected_multi_approximation(self):\n    return self._default_fully_connected_multi_approximation\n\n  def set_default_fully_connected_multi_approximation(self, value):\n    if value not in self._fully_connected_multi_approx_to_block_types:\n      raise ValueError(\"{} is not a valid approximation for a fully-connected \"\n                       \"multi layer.\".format(value))\n    self._default_fully_connected_multi_approximation = value\n\n  @property\n  def default_conv2d_multi_approximation(self):\n    return self._default_conv2d_multi_approximation\n\n  def set_default_conv2d_multi_approximation(self, value):\n    if value not in self._conv2d_multi_approx_to_block_types:\n      raise ValueError(\"{} is not a valid approximation for a conv2d \"\n                       \"multi layer.\".format(value))\n    self._default_conv2d_multi_approximation = value\n\n  @property\n  def default_scale_and_shift_approximation(self):\n    return self._default_scale_and_shift_approximation\n\n  def set_default_scale_and_shift_approximation(self, value):\n    if value not in self._scale_and_shift_approx_to_block_types:\n      raise ValueError(\"{} is not a valid approximation for a scale & shift \"\n                       
\"layer.\".format(value))\n    self._default_scale_and_shift_approximation = value\n\n  def auto_register_layers(self, var_list=None, batch_size=None):\n    \"\"\"Registers remaining unregistered layers automatically using a scanner.\n\n    Requires all function / distribution registrations to be performed\n    (manually) first.\n\n    Registrations will be performed using the default approximation mode for\n    each type, as if the scanner were calling the user-level registration\n    functions in this LayerCollection object (which it will be). These\n    defaults can be overridden using the set_default_XXX_approximation methods\n    for types of layers, or using the define_linked_parameters method for\n    specific parameters.\n\n    This function should only be called after any desired manual registrations\n    are performed. For example, if you have a layer which isn't recognized\n    properly by the scanner, or a layer which you want to register differently.\n\n    Note that this function is an experimental convenience feature which won't\n    work for every possible model architecture. Any layers/parameters that\n    whose structure is not recognized will be registered as \"generic\", which\n    is the worst curvature matrix approximation available in the system, and\n    should be avoided if possible.\n\n    See the docstring for register_layers in graph_search.py for more details.\n\n    Args:\n      var_list: A list of variables that the automatic registration should\n        consider. If you have some trainable variables (i.e. those included in\n        tf.trainable_variables()) that you don't want included you need to pass\n        in this list. (Default: tf.trainable_variables()).\n      batch_size: A `int` representing the batch size. Needs to specified if\n        registering generic variables that don't match any layer patterns or\n        if time/uses is folded. 
If the time/uses dimension is merged with\n        batch then this is used to infer number of uses/time-steps. NOTE: In the\n        replicated context this must be the per-replica batch size, and not\n        the total batch size.\n    \"\"\"\n    if var_list is None:\n      var_list = tf.trainable_variables()\n    graph_search.register_layers(self, var_list, batch_size=batch_size)\n\n  def finalize(self):\n    if not self._finalized:\n      self._create_subgraph()\n      self._finalized = True\n    else:\n      raise ValueError(\"LayerCollection was finalized a second time, which \"\n                       \"indicates an error. Perhaps you used the same \"\n                       \"LayerCollection object in multiple \"\n                       \"optimizers/estimators, which is not allowed.\")\n\n  def _register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):\n    \"\"\"Validates and registers the layer_key associated with the fisher_block.\n\n    Args:\n      layer_key: A variable or tuple of variables. The key to check for in\n          existing registrations and to register if valid.\n      fisher_block: The associated `FisherBlock`.\n      reuse: Method to use for inserting new `FisherBlock`s. One of True, False,\n        or 'VARIABLE_SCOPE'.\n\n    Raises:\n      ValueError: If `layer_key` was already registered and reuse is `False`,\n        if `layer_key` was registered with a different block type, or if\n        `layer_key` shares any variables with but is not equal to a previously\n        registered key.\n      KeyError: If `reuse` is `True` but `layer_key` was not previously\n        registered.\n\n    Returns:\n      The `FisherBlock` registered under `layer_key`. If `layer_key` was already\n      registered, this will be the previously registered `FisherBlock`.\n    \"\"\"\n    if self._finalized:\n      raise ValueError(\"You cannot register additional losses or layers after \"\n                       \"LayerCollection is finalized. 
Finalization happens \"\n                       \"after the estimator or optimizer object first uses \"\n                       \"the data in the LayerCollection. For example, when \"\n                       \"the minimize() method is called in \"\n                       \"PeriodicInvCovUpdateKfacOpt.\")\n\n    if reuse is VARIABLE_SCOPE:\n      reuse = tf.get_variable_scope().reuse\n\n    if reuse is True or (reuse is tf.AUTO_REUSE and\n                         layer_key in self.fisher_blocks):\n\n      if layer_key not in self.fisher_blocks:\n        raise ValueError(\n            \"reuse was True for attempted registration involving variables {}, \"\n            \"but no previously registered layer was found for these. Perhaps \"\n            \"reuse was set to True by mistake. One way this can happen is if \"\n            \"reuse is set to True in the surrounding variable scope.\"\n            \"\".format(layer_key))\n\n      result = self.fisher_blocks[layer_key]\n\n      if type(result) != type(fisher_block):  # pylint: disable=unidiomatic-typecheck\n        raise ValueError(\n            \"Attempted to register FisherBlock of type %s when existing \"\n            \"FisherBlock has type %s.\" % (type(fisher_block), type(result)))\n      return result\n    if reuse is False and layer_key in self.fisher_blocks:\n      raise ValueError(\"FisherBlock for %s is already in LayerCollection.\" %\n                       (layer_key,))\n\n    # Insert fisher_block into self.fisher_blocks.\n    if layer_key in self.fisher_blocks:\n      raise ValueError(\"Duplicate registration: {}\".format(layer_key))\n    # Raise an error if any variable in layer_key has been registered in any\n    # other blocks.\n    variable_to_block = {\n        var: (params, block)\n        for (params, block) in self.fisher_blocks.items()\n        for var in utils.ensure_sequence(params)\n    }\n    for variable in utils.ensure_sequence(layer_key):\n      if variable in variable_to_block:\n       
 prev_key, prev_block = variable_to_block[variable]\n        raise ValueError(\n            \"Attempted to register layer_key {} with block {}, but variable {}\"\n            \" was already registered in key {} with block {}.\".format(\n                layer_key, fisher_block, variable, prev_key, prev_block))\n    self.fisher_blocks[layer_key] = fisher_block\n    return fisher_block\n\n  def _register_loss_function(self,\n                              loss,\n                              colocation_op,\n                              base_name,\n                              name=None,\n                              coeff=1.0,\n                              reuse=VARIABLE_SCOPE):\n    \"\"\"Registers a LossFunction object.\n\n    Args:\n      loss: The LossFunction object.\n      colocation_op: The op to colocate the loss function's computations with.\n      base_name: The name to derive a new unique name from is the name argument\n        is None.\n      name: (OPTIONAL) str or None. Unique name for this loss function. If None,\n        a new name is generated. (Default: None)\n      coeff: (OPTIONAL) a scalar. coefficient on the loss function\n        (Default: 1.0)\n      reuse: (OPTIONAL) bool or str.  If True, adds 'loss' as an additional\n        tower for the existing loss function.\n\n    Raises:\n      ValueError: If reuse == True and name == None.\n      ValueError: If reuse == True and seed != None.\n      KeyError: If reuse == True and no existing LossFunction with 'name' found.\n      KeyError: If reuse == False and existing LossFunction with 'name' found.\n    \"\"\"\n\n    if self._finalized:\n      raise ValueError(\"You cannot register additional losses or layers after \"\n                       \"LayerCollection is finalized. Finalization happens \"\n                       \"after the estimator or optimizer object first uses \"\n                       \"the data in the LayerCollection. 
For example, when \"\n                       \"the minimize() method is called in \"\n                       \"PeriodicInvCovUpdateKfacOpt.\")\n\n    name = name or self._graph.unique_name(base_name)\n\n    if reuse == VARIABLE_SCOPE:\n      reuse = tf.get_variable_scope().reuse\n\n    if reuse:\n      if name is None:\n        raise ValueError(\n            \"If reuse is enabled, loss function's name must be set.\")\n\n      loss_list = self._loss_dict.get(name, None)\n\n      if loss_list is None:\n        raise KeyError(\n            \"Unable to find loss function named {}. Register a new loss \"\n            \"function with reuse=False.\".format(name))\n\n      if self.loss_coeffs[loss_list[0]] != coeff:\n        raise ValueError(\n            \"Reused loss function's coeff didn't match previous supplied \"\n            \"value.\")\n\n    else:\n      if name in self._loss_dict:\n        raise KeyError(\n            \"Loss function named {} already exists. Set reuse=True to append \"\n            \"another tower.\".format(name))\n\n      loss_list = []\n      self._loss_dict[name] = loss_list\n\n    loss_list.append(loss)\n    self.loss_colocation_ops[loss] = colocation_op\n    self.loss_coeffs[loss] = coeff\n\n  def _get_use_count_map(self):\n    \"\"\"Returns a dict mapping variables to their number of registrations.\"\"\"\n    return self._vars_to_uses\n\n  def _add_uses(self, params, uses):\n    \"\"\"Register additional uses by params in the graph.\n\n    Args:\n      params: Variable or tuple of Variables. Parameters for a layer.\n      uses: int or float. 
Number of additional uses for these parameters.\n    \"\"\"\n    params = params if isinstance(params, (tuple, list)) else (params,)\n    for var in params:\n      self._vars_to_uses[var] += uses\n\n  def check_registration(self, variables):\n    \"\"\"Checks that all variable uses have been registered properly.\n\n    Args:\n      variables: List of variables.\n\n    Raises:\n      ValueError: If any registered variables are not included in the list.\n      ValueError: If any variable in the list is not registered.\n      ValueError: If any variable in the list is registered with the wrong\n          number of \"uses\" in the subgraph recorded (vs the number of times that\n          variable is actually used in the subgraph).\n    \"\"\"\n    # Note that overlapping parameters (i.e. those that share variables) will\n    # be caught by layer_collection.LayerParametersDict during registration.\n\n    reg_use_map = self._get_use_count_map()\n\n    error_messages = []\n\n    for var in variables:\n      total_uses = self.subgraph.variable_uses(var)\n      reg_uses = reg_use_map[var]\n\n      if reg_uses == 0:\n        error_messages.append(\"Variable {} not registered.\".format(var))\n      elif (not math.isinf(reg_uses)) and reg_uses != total_uses:\n        error_messages.append(\n            \"Variable {} registered with wrong number of uses ({} uses \"\n            \"registered vs {} uses found in sub-graph generated from \"\n            \"registered losses).\".format(var, reg_uses, total_uses))\n\n    num_get_vars = len(reg_use_map)\n\n    if num_get_vars > len(variables):\n      error_messages.append(\"{} registered variables were not included in list.\"\n                            .format(num_get_vars - len(variables)))\n\n    if error_messages:\n      error_string = \"\\n\\t\".join([\n          \"Found the following errors with variable registration:\"\n      ] + error_messages)\n      raise ValueError(error_string)\n\n  def get_blocks(self):\n    return 
tuple(self.fisher_blocks.values())\n\n  def get_factors(self):\n    return tuple(self.fisher_factors.values())\n\n  @property\n  def graph(self):\n    return self._graph\n\n  @property\n  def subgraph(self):\n    return self._subgraph\n\n  def define_linked_parameters(self, params, approximation=None):\n    \"\"\"Identify a set of parameters that should be grouped together.\n\n    Also allows the approximation type string to be set for the given\n    parameter grouping.\n\n    During automatic graph scanning (as done by the auto_register_layers method)\n    any matches containing variables that have been identified as part of a\n    linked group will be filtered out unless the match parameters are exactly\n    equal to the ones specified in the linked group.\n\n    Args:\n      params: A variable, or a tuple or list of variables. The variables\n        to be linked.\n      approximation: Optional string specifying the type of approximation to use\n        for these variables. If unspecified, this layer collection's default\n        approximation for the layer type will be used.\n\n    Raises:\n      ValueError: If the parameters were already registered in a layer or\n        identified as part of an incompatible group.\n    \"\"\"\n    params = frozenset(utils.ensure_sequence(params))\n\n    # Check if any of the variables in 'params' is already in\n    # 'self.fisher_blocks.keys()'.\n    for registered_params, fisher_block in self.fisher_blocks.items():\n      registered_params_set = set(utils.ensure_sequence(registered_params))\n      for variable in params:\n        if (variable in registered_params_set and\n            params != registered_params_set):\n          raise ValueError(\n              \"Can't link parameters {}, variable {} was already registered in \"\n              \"group {} with layer {}\".format(params, variable,\n                                              registered_params, fisher_block))\n\n    # Check if any of the variables in 'params' is 
already in\n    # 'self.linked_parameters'.\n    for variable in params:\n      for other_linked_params in self.linked_parameters:\n        if variable in other_linked_params:\n          raise ValueError(\"Can't link parameters {}, variable {} was already \"\n                           \"linked in group {}.\".format(params, variable,\n                                                        other_linked_params))\n    self._linked_parameters[params] = approximation\n\n  def _create_subgraph(self):\n    if not self.losses:\n      raise ValueError(\"Must have at least one registered loss.\")\n    inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses))\n    self._subgraph = utils.SubGraph(inputs_to_losses)\n\n  def eval_losses(self, target_mode=\"data\", coeff_mode=\"regular\"):\n    \"\"\"Returns evaluated losses (colocated with inputs to losses).\"\"\"\n    evals = []\n    for loss in self.losses:\n      with tf.colocate_with(self.loss_colocation_ops[loss]):\n        if target_mode == \"data\":\n          loss_value = loss.evaluate()\n        elif target_mode == \"sample\":\n          loss_value = loss.evaluate_on_sample()\n        else:\n          raise ValueError(\"target_mode must be in ['data', 'sample']\")\n\n        if coeff_mode == \"regular\":\n          multiplier = self.loss_coeffs[loss]\n        elif coeff_mode == \"sqrt\":\n          multiplier = tf.sqrt(self.loss_coeffs[loss])\n        elif coeff_mode == \"off\":\n          multiplier = 1.0\n        else:\n          raise ValueError(\"coeff_mode must be in ['regular', 'sqrt', 'off']\")\n        multiplier = tf.cast(multiplier, dtype=loss_value.dtype)\n        evals.append(multiplier * loss_value)\n    return evals\n\n  def total_loss(self, coeff_mode=\"regular\"):\n    return tf.add_n(self.eval_losses(target_mode=\"data\",\n                                     coeff_mode=coeff_mode))\n\n  def total_sampled_loss(self, coeff_mode=\"regular\"):\n    return 
tf.add_n(self.eval_losses(target_mode=\"sample\",\n                                     coeff_mode=coeff_mode))\n\n  def _get_linked_approx(self, params):\n    \"\"\"If params were linked, return their specified approximation.\"\"\"\n    params_set = frozenset(utils.ensure_sequence(params))\n    if params_set in self.linked_parameters:\n      return self.linked_parameters[params_set]\n    else:\n      return None\n\n  def _get_block_type(self, params, approx, default, approx_to_type):\n    if approx is None:\n      approx = self._get_linked_approx(params)\n      if approx is None:\n        approx = default\n\n    if approx not in approx_to_type:\n      raise ValueError(\"Bad value {} for approx.\".format(approx))\n\n    return approx_to_type[approx], approx\n\n  def register_fully_connected(self,\n                               params,\n                               inputs,\n                               outputs,\n                               approx=None,\n                               dense_inputs=True,\n                               reuse=VARIABLE_SCOPE):\n    \"\"\"Registers a fully connected layer.\n\n    Args:\n      params: Variable or 2-tuple of variables corresponding to weight and\n        bias parameters of this layer. Weight matrix should have shape\n        [input_size, output_size]. Bias should have shape [output_size].\n      inputs: Tensor. Two formats are accepted. In most cases the Tensor is\n        dense inputs, with shape [batch_size, input_size]. In some cases\n        the Tensor is sparse inputs, with shape [batch_size]. A typical example\n        of sparse inputs is the vocab indices into an embedding matrix. Sparse\n        inputs will be converted to the dense format within KFAC. For sparse\n        inputs, dense_inputs should be set to False.\n      outputs: Tensor of shape [batch_size, output_size]. Outputs\n        produced by layer.\n      approx: str or None. 
If not None must be one of \"kron\", \"kron_in_diag\"\n        (diagonal approximation for the input kronecker factor), \"kron_out_diag\"\n        (diagonal approximation for the output kronecker factor),\n        \"kron_both_diag\" or \"diagonal\". The Fisher approximation to use. If\n        None the default value is used. (Default: None)\n      dense_inputs: bool. True if inputs are dense inputs. (Default: True)\n      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an\n        additional mini-batch/tower of data to use when estimating the Fisher\n        block for this layer (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n\n    Raises:\n      ValueError: For improper value to 'approx'.\n      KeyError: If reuse == True but no FisherBlock found for 'params'.\n      ValueError: If reuse == True and FisherBlock found but of the wrong type.\n    \"\"\"\n\n    block_type, approx = self._get_block_type(\n        params, approx, self.default_fully_connected_approximation,\n        self._fully_connected_approx_to_block_types)\n\n    has_bias = isinstance(params, (tuple, list))\n    block = self._register_block(\n        params, block_type(self, has_bias=has_bias), reuse=reuse)\n\n    if not dense_inputs:\n      inputs.one_hot_depth = int(params.shape[0])\n    block.register_additional_tower(inputs, outputs)\n\n    self._add_uses(params, 1)\n\n  def register_conv1d(self,\n                      params,\n                      strides,\n                      padding,\n                      inputs,\n                      outputs,\n                      dilations=None,\n                      approx=None,\n                      reuse=VARIABLE_SCOPE,\n                      sub_sample_inputs=None,\n                      sub_sample_patches=None):\n    \"\"\"Registers a call to tf.nn.conv1d().\n\n    Args:\n      params: Variablle or 2-tuple of variables 
corresponding to weight and\n        bias parameters this layer. Weight matrix should have shape\n        [kernel_width, in_channels, out_channels].  Bias should have shape\n        [out_channels].\n      strides: List of 3 ints. Strides for convolution kernel.\n      padding: string. see tf.nn.conv2d for valid values.\n      inputs: Tensor of shape [batch_size, width, in_channels]. Inputs\n        to layer.\n      outputs: Tensor of shape [batch_size, width, out_channels].\n        Output produced by layer.\n      dilations: List of 3 ints. Dilations along each dimension.\n      approx: str or None. If not None, must be \"kron\". The Fisher approximation\n        to use. If None, the default value is used. (Default: None)\n      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an\n        additional mini-batch/tower of data to use when estimating the Fisher\n        block for this layer (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n      sub_sample_inputs: `bool`. If True, then subsample the inputs from which\n        the image patches are extracted. (Default: None)\n      sub_sample_patches: `bool`, If `True` then subsample the extracted\n        patches. 
(Default: None)\n\n    Raises:\n      KeyError: If reuse == True but no FisherBlock found for 'params'.\n      ValueError: If reuse == True and FisherBlock found but of the wrong type.\n    \"\"\"\n    assert approx is None or approx == APPROX_KRONECKER_NAME\n\n    block = self._register_block(\n        params,\n        fb.ConvKFCBasicFB(\n            layer_collection=self,\n            params=params,\n            padding=padding,\n            strides=strides,\n            data_format=\"NWC\",\n            dilation_rate=dilations,\n            extract_patches_fn=\"extract_convolution_patches\",\n            sub_sample_inputs=sub_sample_inputs,\n            sub_sample_patches=sub_sample_patches,\n            use_sua_approx_for_input_factor=False),\n        reuse=reuse)\n    block.register_additional_tower(inputs, outputs)\n\n    self._add_uses(params, 1)\n\n  def register_conv2d(self,\n                      params,\n                      strides,\n                      padding,\n                      inputs,\n                      outputs,\n                      data_format=None,\n                      dilations=None,\n                      approx=None,\n                      reuse=VARIABLE_SCOPE,\n                      sub_sample_inputs=None,\n                      sub_sample_patches=None,\n                      patch_mask=None):\n    \"\"\"Registers a call to tf.nn.conv2d().\n\n    Args:\n      params: Variable or 2-tuple of variables corresponding to weight and\n        bias parameters of this layer. Weight matrix should have shape\n        [kernel_height, kernel_width, in_channels, out_channels].  Bias should\n        have shape [out_channels].\n      strides: List of 4 ints. Strides for convolution kernel.\n      padding: string. see tf.nn.conv2d for valid values.\n      inputs: Tensor of shape [batch_size, height, width, in_channels]. 
Inputs\n        to layer.\n      outputs: Tensor of shape [batch_size, height, width, out_channels].\n        Output produced by layer.\n      data_format: str or None. Format of data. If None, this should default\n        to 'NWHC'. (Default: None)\n      dilations: List of 4 ints. Dilations along each dimension.\n      approx: str or None. If not None must be one of \"kron\" or \"diagonal\".\n        The Fisher approximation to use. If None the default value is used.\n        (Default: None)\n      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an\n        additional mini-batch/tower of data to use when estimating the Fisher\n        block for this layer (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n      sub_sample_inputs: `bool`. If True, then subsample the inputs from which\n        the image patches are extracted. (Default: None)\n      sub_sample_patches: `bool`, If `True` then subsample the extracted\n        patches. (Default: None)\n      patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]\n        or None. If not None this is multiplied against the extracted patches\n        Tensor (broadcasting along the batch dimension) before statistics are\n        computed. This can (and probably should) be used if the filter bank\n        matrix is masked in a way that is homogenous across the output channels.\n        (Other masking patterns have no direct support.) Currently only works\n        with the approx=\"kron\" or \"diagonal\". 
(Default: None)\n    Raises:\n      ValueError: For improper value to 'approx'.\n      KeyError: If reuse == True but no FisherBlock found for 'params'.\n      ValueError: If reuse == True and FisherBlock found but of the wrong type.\n    \"\"\"\n    assert data_format in [None, \"NHWC\"]  # We don't support NCHW right now\n\n    block_type, approx = self._get_block_type(\n        params, approx, self.default_conv2d_approximation,\n        self._conv2d_approx_to_block_types)\n\n    # It feels bad to pass in configuration that has to do with the internal\n    # implementation.  And then we can't use the same constructor for both\n    # anymore and are thus forced to use this ugly if-statement.\n    # TODO(b/74793309): Clean this up?\n    if approx == APPROX_KRONECKER_NAME:\n      block = self._register_block(\n          params,\n          block_type(\n              layer_collection=self,\n              params=params,\n              padding=padding,\n              strides=strides,\n              data_format=data_format,\n              dilation_rate=dilations,\n              extract_patches_fn=\"extract_image_patches\",\n              sub_sample_inputs=sub_sample_inputs,\n              sub_sample_patches=sub_sample_patches,\n              use_sua_approx_for_input_factor=False,\n              patch_mask=patch_mask),\n          reuse=reuse)\n    elif approx == APPROX_DIAGONAL_NAME:\n      assert strides[0] == strides[-1] == 1\n      block = self._register_block(\n          params,\n          block_type(\n              layer_collection=self,\n              params=params,\n              padding=padding,\n              strides=strides,\n              dilations=dilations,\n              data_format=data_format,\n              patch_mask=patch_mask),\n          reuse=reuse)\n    elif approx == APPROX_KRONECKER_SUA_NAME:\n      block = self._register_block(\n          params,\n          block_type(\n              layer_collection=self,\n              params=params,\n          
    padding=padding,\n              use_sua_approx_for_input_factor=True),\n          reuse=reuse)\n\n    else:\n      raise NotImplementedError(approx)\n\n    block.register_additional_tower(inputs, outputs)\n\n    self._add_uses(params, 1)\n\n  def register_convolution(self,\n                           params,\n                           inputs,\n                           outputs,\n                           padding,\n                           strides=None,\n                           dilation_rate=None,\n                           data_format=None,\n                           approx=None,\n                           reuse=VARIABLE_SCOPE):\n    \"\"\"Register a call to tf.nn.convolution().\n\n    Unless you know what you are doing you should be using register_conv2d\n    instead.\n\n    Args:\n      params: Variable or 2-tuple of variables corresponding to weight and\n        bias parameters of this layer. Weight matrix should have shape\n        [..filter_spatial_size.., in_channels, out_channels].  Bias should have\n        shape [out_channels].\n      inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels].\n        Inputs to layer.\n      outputs: Tensor of shape [batch_size, ..output_spatial_size..,\n        out_channels].  Output produced by layer.\n      padding: string. see tf.nn.conv2d for valid values.\n      strides: List of ints of length len(..input_spatial_size..). Strides for\n        convolution kernel in spatial dimensions.\n      dilation_rate: List of ints of length len(..input_spatial_size..).\n        Dilations along spatial dimension.\n      data_format: str or None. Format of data.\n      approx: str or None. If not None, must be \"kron\". The Fisher approximation\n        to use. If None, the default value is used. (Default: None)\n      reuse: bool or str.  
If True, this adds 'inputs' and 'outputs' as an\n        additional mini-batch/tower of data to use when estimating the Fisher\n        block for this layer (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n\n    Raises:\n      ValueError: For improper value to 'approx'.\n      KeyError: If reuse == True but no FisherBlock found for 'params'.\n      ValueError: If reuse == True and FisherBlock found but of the wrong type.\n    \"\"\"\n    # TODO(b/74793309): Have this use _get_block_type like the other\n    # registration functions?\n    assert approx is None or approx == APPROX_KRONECKER_NAME\n\n    block = self._register_block(\n        params,\n        fb.ConvKFCBasicFB(\n            layer_collection=self,\n            params=params,\n            padding=padding,\n            strides=strides,\n            dilation_rate=dilation_rate,\n            data_format=data_format),\n        reuse=reuse)\n    block.register_additional_tower(inputs, outputs)\n\n    self._add_uses(params, 1)\n\n  def register_depthwise_conv2d(self,\n                                params,\n                                inputs,\n                                outputs,\n                                strides,\n                                padding,\n                                rate=None,\n                                data_format=None,\n                                approx=None,\n                                reuse=VARIABLE_SCOPE):\n    \"\"\"Register a call to tf.nn.depthwise_conv2d().\n\n    Note that this is an experimental feature that hasn't been experimentally\n    validated or published on.\n\n    Args:\n      params: 4-D variable of shape [filter_height, filter_width, in_channels,\n        channel_multiplier].  Convolutional filter.\n      inputs: Tensor of shape [batch_size, input_height, input_width,\n        in_channels].  
Inputs to layer.\n      outputs: Tensor of shape [batch_size, output_height, output_width,\n        in_channels * channel_multiplier].  Output produced by depthwise conv2d.\n      strides: List of ints of length 4. Strides along all dimensions.\n      padding: string. see tf.nn.conv2d for valid values.\n      rate: None or List of ints of length 2. Dilation rates in spatial\n        dimensions.\n      data_format: str or None. Format of data.\n      approx: str or None. If not None must \"diagonal\".  The Fisher\n        approximation to use. If None the default value is used. (Default: None)\n      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an\n        additional mini-batch/tower of data to use when estimating the Fisher\n        block for this layer (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n\n    Raises:\n      ValueError: For improper value to 'approx'.\n      KeyError: If reuse == True but no FisherBlock found for 'params'.\n      ValueError: If reuse == True and FisherBlock found but of the wrong type.\n    \"\"\"\n    # TODO(b/74793309): Have this use _get_block_type like the other\n    # registration functions?\n    assert approx is None or approx == APPROX_DIAGONAL_NAME\n    assert data_format in [None, \"NHWC\"]\n\n    block = self._register_block(\n        params,\n        fb.DepthwiseConvDiagonalFB(\n            layer_collection=self,\n            params=params,\n            strides=strides,\n            padding=padding,\n            rate=rate,\n            data_format=data_format),\n        reuse=reuse)\n    block.register_additional_tower(inputs, outputs)\n\n    self._add_uses(params, 1)\n\n  def register_separable_conv2d(self,\n                                depthwise_params,\n                                pointwise_params,\n                                inputs,\n                                depthwise_outputs,\n    
                            pointwise_outputs,\n                                strides,\n                                padding,\n                                rate=None,\n                                data_format=None,\n                                approx=None,\n                                reuse=VARIABLE_SCOPE):\n    \"\"\"Register a call to tf.nn.separable_conv2d().\n\n    Note: This requires access to intermediate outputs between depthwise and\n    pointwise convolutions.\n\n    Note that this is an experimental feature that hasn't been experimentally\n    validated or published on.\n\n    Args:\n      depthwise_params: 4-D variable of shape [filter_height, filter_width,\n        in_channels, channel_multiplier].  Filter for depthwise conv2d.\n      pointwise_params: 4-D variable of shape [1, 1, in_channels *\n        channel_multiplier, out_channels].  Filter for pointwise conv2d.\n      inputs: Tensor of shape [batch_size, input_height, input_width,\n        in_channels].  Inputs to layer.\n      depthwise_outputs: Tensor of shape [batch_size, output_height,\n        output_width, in_channels * channel_multiplier].  Output produced by\n        depthwise conv2d.\n      pointwise_outputs: Tensor of shape [batch_size, output_height,\n        output_width, out_channels].  Output produced by pointwise conv2d.\n      strides: List of ints of length 4. Strides for depthwise conv2d kernel in\n        all dimensions.\n      padding: string. see tf.nn.conv2d for valid values.\n      rate: None or List of ints of length 2. Dilation rate of depthwise conv2d\n        kernel in spatial dimensions.\n      data_format: str or None. Format of data.\n      approx: str or None. If not None must be one of \"kron\" or \"diagonal\".\n        The Fisher approximation to use. If None the default value is used.\n        (Default: None)\n      reuse: bool or str.  
If True, this adds 'inputs' and 'outputs' as an\n        additional mini-batch/tower of data to use when estimating the Fisher\n        block for this layer (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n\n    Raises:\n      ValueError: For improper value to 'approx'.\n      KeyError: If reuse == True but no FisherBlock found for 'params'.\n      ValueError: If reuse == True and FisherBlock found but of the wrong type.\n    \"\"\"\n    self.register_depthwise_conv2d(\n        params=depthwise_params,\n        inputs=inputs,\n        outputs=depthwise_outputs,\n        strides=strides,\n        padding=padding,\n        rate=rate,\n        data_format=data_format,\n        approx=APPROX_DIAGONAL_NAME,\n        reuse=reuse)\n\n    self.register_conv2d(\n        params=pointwise_params,\n        inputs=depthwise_outputs,\n        outputs=pointwise_outputs,\n        strides=[1, 1, 1, 1],\n        padding=\"VALID\",\n        data_format=data_format,\n        approx=approx,\n        reuse=reuse)\n\n  def register_generic(self,\n                       params,\n                       batch_size,\n                       approx=None,\n                       reuse=VARIABLE_SCOPE):\n    \"\"\"Registers parameters without assuming any structure.\n\n    Note that this is an approximation of last resort and should be avoided if\n    anything else will work.\n\n    Args:\n      params: Variable or tuple of variables corresponding to the parameters.\n        If using \"diagonal\" approximation this must be a single variable.\n      batch_size: 0-D Tensor. Size of the minibatch (for this tower).\n      approx: str or None. It not None, must be one of \"full\" or \"diagonal\".\n        The Fisher approximation to use. If None the default value is used.\n        (Default: None)\n      reuse: bool or str. 
If True, this adds 'batch_size' to the total\n        mini-batch size use when estimating the Fisher block for this layer\n        (which must have already been registered). If \"VARIABLE_SCOPE\", use\n        tf.get_variable_scope().reuse. (Default: \"VARIABLE_SCOPE\")\n\n    Raises:\n      ValueError: For improper value to 'approx'.\n      KeyError: If reuse == True but no FisherBlock found for 'params'.\n      ValueError: If reuse == True and FisherBlock found but of the wrong type.\n      ValueError: If approx == \"diagonal\" and params is a tuple.\n    \"\"\"\n    block_type, approx = self._get_block_type(\n        params, approx, self.default_generic_approximation,\n        self._generic_approx_to_block_types)\n\n    if approx == APPROX_DIAGONAL_NAME and isinstance(params, (tuple, list)):\n      raise ValueError(\"Params must be a Variable if using the diagonal \"\n                       \"approximation.\")\n\n    block = self._register_block(params, block_type(self, params), reuse=reuse)\n    block.register_additional_tower(batch_size)\n\n    self._add_uses(params, float(\"inf\"))\n\n  def register_fully_connected_multi(self, params, inputs, outputs,\n                                     num_uses=None, approx=None,\n                                     dense_inputs=True, reuse=VARIABLE_SCOPE):\n    \"\"\"Register fully connected layers with shared parameters.\n\n    This can handle general fully-connected layers with shared parameters, but\n    has specialized approximations to deal with the case where there is a\n    meaningful linear order to the share instances (such as in an RNN).\n\n    Note that padding is *not* supported. The arguments to this method cannot\n    be zero-padded or anything of that sort.\n\n    Args:\n      params: Variable or 2-tuple of variables corresponding to weight and\n        bias of this layer. Weight matrix should have shape [input_size,\n        output_size]. 
Bias should have shape [output_size].\n      inputs: A list of Tensors or a single Tensor. Inputs to this layer. If a\n        list of Tensors, the list indexes each use in the model (which might\n        correspond to a \"time-step\" in an RNN). Each Tensor in the list has\n        leading dimension batch_size. If a single Tensor, should have shape\n        [num_uses, batch_size, input_size] or be a reshape of such a tensor\n        to shape [num_uses, batch_size, input_size]. Similar to\n        register_fully_connected(), two formats of tensors are accepted: dense\n        inputs and sparse inputs. In most cases the Tensors are dense inputs,\n        with shape [batch_size, input_size] (if a list) or\n        [num_uses, batch_size, input_size] (if a single Tensor) or\n        [num_uses * batch_size, input_size] (if a single Tensor). In some cases\n        the Tensors are sparse inputs, with shape [batch_size] (if a list) or\n        or [num_uses, batch_size] (if a single Tensor) or\n        [num_uses * batch_size] (if a single Tensor). A typical example of\n        sparse inputs is the vocab indices into an embedding matrix. For sparse\n        inputs, the argument 'dense_inputs' should be set to False.\n      outputs: A list of Tensors, the same length as 'inputs', each of shape\n        [batch_size, output_size]. Outputs produced by layer. The list indexes\n        each use in the model (which might correspond to a \"time-step\" in an\n        RNN). Needs to correspond with the order used in 'inputs'.  OR, can be\n        a single Tensor of shape [num_uses * batch_size, output_size], which is\n        a reshaped version of a Tensor of shape [num_uses, batch_size,\n        output_size].\n      num_uses: int or None. The number uses/time-steps in the model where the\n        layer appears. Only needed if both inputs and outputs are given in the\n        single Tensor format. (Default: None)\n      approx: str or None. 
If not None, must be one of \"kron_indep\",\n        \"kron_indep_in_diag\" (diagonal approximation for the input kronecker\n        factor), \"kron_indep_out_diag\" (diagonal approximation for the output\n        kronecker factor), \"kron_indep_both_diag\", \"kron_series_1\" or\n        \"kron_series_2\". The Fisher approximation to use. If None the default\n        value is used (which starts out as \"kron_indep\"). (Default: None)\n      dense_inputs: bool. True if inputs are dense inputs. (Default: True)\n      reuse: bool or str.  If True, this adds inputs and outputs as an\n        additional mini-batch/tower of data to use when estimating the Fisher\n        block for this layer (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.  (Note that the\n        word 'use' here has a completely different meaning to \"use in the model\"\n        as it pertains to the 'inputs', 'outputs', and 'num_uses' arguments.)\n        (Default: \"VARIABLE_SCOPE\")\n\n    Raises:\n      ValueError: For improper value to 'approx'.\n    \"\"\"\n    block_type, approx = self._get_block_type(\n        params, approx, self.default_fully_connected_multi_approximation,\n        self._fully_connected_multi_approx_to_block_types)\n\n    # TODO(b/70283649): something along the lines of find_canonical_output\n    # should be added back in here (and for the other block types, arguably).\n    has_bias = isinstance(params, (tuple, list))\n    block = self._register_block(\n        params,\n        block_type(self, has_bias=has_bias, num_uses=num_uses),\n        reuse=reuse)\n\n    if isinstance(inputs, (tuple, list)):\n      inputs = tuple(inputs)\n    if isinstance(outputs, (tuple, list)):\n      outputs = tuple(outputs)\n\n    if not dense_inputs:\n      if isinstance(inputs, (tuple, list)):\n        for input in inputs:\n          input.one_hot_depth = int(params.shape[0])\n      else:\n        inputs.one_hot_depth = 
int(params.shape[0])\n\n    block.register_additional_tower(inputs, outputs)\n    if isinstance(inputs, (tuple, list)):\n      assert len(inputs) == len(outputs)\n      self._add_uses(params, len(inputs))\n    else:\n      self._add_uses(params, 1)\n\n  def register_conv2d_multi(self,\n                            params,\n                            strides,\n                            padding,\n                            inputs,\n                            outputs,\n                            num_uses=None,\n                            data_format=None,\n                            dilations=None,\n                            approx=None,\n                            reuse=VARIABLE_SCOPE):\n    \"\"\"Registers convolutional layers with shared parameters.\n\n    Note that padding is *not* supported. The arguments to this method cannot\n    be zero-padded or anything of that sort.\n\n    Args:\n      params: Variable or 2-tuple of variables corresponding to weight and\n        bias of this layer. Weight matrix should have shape [kernel_height,\n        kernel_width, in_channels, out_channels].  Bias should have shape\n        [out_channels].\n      strides: 1-D Tensor of length 4. Strides for convolution kernel.\n      padding: string. see tf.nn.conv2d for valid values.\n      inputs: A list of Tensors, each of shape [batch_size, height, width,\n        in_channels]. Inputs to layer. The list indexes each use in the model\n        (which might correspond to a \"time-step\" in an RNN). OR, can be single\n        Tensor, of shape [num_uses * batch_size, height, width, in_channels],\n        which is a reshaped version of a Tensor of shape [num_uses, batch_size,\n        height, width, in_channels].\n      outputs: A list of Tensors, each of shape [batch_size, height, width,\n        out_channels]. Output produced by layer. 
The list indexes each use\n        in the model (which might correspond to a \"time-step\" in an RNN).\n        Needs to correspond with the order used in 'inputs'.  OR, can be a\n        single Tensor, of shape [num_uses * batch_size, height, width,\n        out_channels], which is a reshaped version of a Tensor of shape\n        [num_uses, batch_size, height, width, out_channels].\n      num_uses: int or None. The number uses/time-steps in the model where the\n        layer appears. Only needed if both inputs and outputs are given in the\n        single Tensor format. (Default: None)\n      data_format: str or None. Format of data.\n      dilations: List of 4 ints. Dilations along each dimension.\n      approx: str or None. If not None must be \"kron_indep\". The Fisher\n        approximation to use. If None the default value is used (which starts\n        out as \"kron_indep\"). (Default: None)\n      reuse: bool or str.  If True, this adds inputs and outputs as an\n        additional mini-batch/tower of data to use when estimating the Fisher\n        block for this layer (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.  
(Note that the\n        word 'use' here has a completely different meaning to \"use in the model\"\n        as it pertains to the 'inputs', 'outputs', and 'num_uses' arguments.)\n        (Default: \"VARIABLE_SCOPE\")\n\n    Raises:\n      ValueError: For improper value to 'approx'.\n      KeyError: If reuse == True but no FisherBlock found for 'params'.\n      ValueError: If reuse == True and FisherBlock found but of the wrong type.\n    \"\"\"\n    assert data_format in [None, \"NHWC\"]  # We don't support NCHW right now\n\n    block_type, approx = self._get_block_type(\n        params, approx, self.default_conv2d_multi_approximation,\n        self._conv2d_multi_approx_to_block_types)\n\n    block = self._register_block(\n        params,\n        block_type(\n            layer_collection=self,\n            params=params,\n            padding=padding,\n            strides=strides,\n            data_format=data_format,\n            dilation_rate=dilations,\n            extract_patches_fn=\"extract_image_patches\",\n            num_uses=num_uses),\n        reuse=reuse)\n\n    if isinstance(inputs, (tuple, list)):\n      inputs = tuple(inputs)\n    if isinstance(outputs, (tuple, list)):\n      outputs = tuple(outputs)\n\n    block.register_additional_tower(inputs, outputs)\n    if isinstance(inputs, (tuple, list)):\n      assert len(inputs) == len(outputs)\n      self._add_uses(params, len(inputs))\n    else:\n      self._add_uses(params, 1)\n\n  def register_scale_and_shift(self,\n                               params,\n                               inputs,\n                               outputs,\n                               approx=None,\n                               reuse=VARIABLE_SCOPE):\n    \"\"\"Registers a scale and shift operation.\n\n    A scale and shift operation is a parameterized operation of the form\n\n    outputs = scale * inputs + shift ,\n\n    where scale and shift are variables that broadcast to the shape of inputs.\n\n    outputs and inputs 
must have batch dimension. scale and shift can have\n    a corresponding dimension (although they don't need to), but it must\n    be 1.\n\n    These kinds of operations appear frequently in various \"normalization\"\n    layers like Layer Normalization. Batch Normalization layers should still\n    be registered as \"generic\".\n\n    Note that this is an experimental feature that hasn't been experimentally\n    validated or published on.\n\n    Args:\n      params: Variable or 2-tuple of Variables corresponding to the scale\n        and possibly shift parameters (scale must be first).  Note that if\n        these have a dimension corresponding to the batch dimension of 'inputs'\n        and 'outputs', that dimension must be 1.\n      inputs: Tensor of shape [batch_size, ...]. Input tensor that is multiplied\n        by the scale the scale tensor.\n      outputs: Tensor of shape [batch_size, ...]. Final output produced by the\n        scale and shift. Must have the same shape as 'inputs'.\n      approx: str or None. If not None must be one of \"full\" or \"diagonal\".\n        The Fisher approximation to use. If None the default value is used.\n        (Default: None)\n      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an\n        additional mini-batch/tower of data to use when estimating the Fisher\n        block for this layer (which must have already been registered). 
If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n\n    Raises:\n      ValueError: For improper value to 'approx'.\n      KeyError: If reuse == True but no FisherBlock found for 'params'.\n      ValueError: If reuse == True and FisherBlock found but of the wrong type.\n    \"\"\"\n    # TODO(jamesmartens): Consider replacing some of the logic below with calls\n    # to tf.broadcast_static_shape.\n    if isinstance(params, (tuple, list)):\n      scale = params[0]\n      shift = params[1]\n\n      has_shift = True\n\n      start_dim = len(outputs.shape) - len(shift.shape)\n      if start_dim < 0:\n        raise ValueError(\"Rank of shift cannot exceed that of outputs.\")\n      if start_dim == 0 and shift.shape[0] != 1:\n        raise ValueError(\"If shift has a batch dimension its value must be 1.\")\n      broadcast_dims_shift = list(range(1, start_dim))\n      for i in range(max(start_dim, 1), len(outputs.shape)):\n        if shift.shape[i - start_dim] < outputs.shape[i]:\n          if shift.shape[i - start_dim] == 1:\n            broadcast_dims_shift.append(i)\n          else:\n            raise ValueError(\"It appears that shift param and output have \"\n                             \"incompatible shapes. This is probably due to \"\n                             \"misspecified arguments.\")\n        elif shift.shape[i - start_dim] > outputs.shape[i]:\n          raise ValueError(\"It appears that shift param and output have \"\n                           \"incompatible shapes. 
This is probably due to \"\n                           \"misspecified arguments.\")\n      broadcast_dims_shift = tuple(broadcast_dims_shift)\n    else:\n      has_shift = False\n      scale = params\n      broadcast_dims_shift = None\n\n    start_dim = len(inputs.shape) - len(scale.shape)\n    if start_dim < 0:\n      raise ValueError(\"Rank of scale cannot exceed that of inputs.\")\n    if start_dim == 0 and scale.shape[0] != 1:\n      raise ValueError(\"If scale has a batch dimension its value must be 1.\")\n    broadcast_dims_scale = list(range(1, start_dim))\n    for i in range(max(start_dim, 1), len(inputs.shape)):\n      if scale.shape[i - start_dim] < inputs.shape[i]:\n        if scale.shape[i - start_dim] == 1:\n          broadcast_dims_scale.append(i)\n        else:\n          raise ValueError(\"It appears that scale param and input have \"\n                           \"incompatible shapes. This is probably due to \"\n                           \"misspecified arguments.\")\n    broadcast_dims_scale = tuple(broadcast_dims_scale)\n\n    block_type, approx = self._get_block_type(\n        params, approx, self.default_scale_and_shift_approximation,\n        self._scale_and_shift_approx_to_block_types)\n\n    block = self._register_block(params, block_type(\n        self,\n        broadcast_dims_scale,\n        broadcast_dims_shift=broadcast_dims_shift,\n        has_shift=has_shift),\n                                 reuse=reuse)\n    block.register_additional_tower(inputs, outputs)\n\n    self._add_uses(params, 1)\n\n  def register_categorical_predictive_distribution(self,\n                                                   logits,\n                                                   seed=None,\n                                                   targets=None,\n                                                   name=None,\n                                                   coeff=1.0,\n                                                   reuse=VARIABLE_SCOPE):\n   
 \"\"\"Registers a categorical predictive distribution.\n\n    Corresponds to losses computed using\n    tf.nn.sparse_softmax_cross_entropy_with_logits.\n\n    Note that this is distinct from\n    register_multi_bernoulli_predictive_distribution and should not be confused\n    with it.\n\n    Args:\n      logits: The logits of the distribution (i.e. its parameters). The first\n        dimension must be the batch size.\n      seed: The seed for the RNG (for debugging) (Default: None)\n      targets: (OPTIONAL) The targets for the loss function.  Only required if\n        one wants to use the \"empirical Fisher\" instead of the true Fisher\n        (which is controlled by the 'estimation_mode' to the optimizer).\n        (Default: None)\n      name: (OPTIONAL) str or None. Unique name for this loss function. If None,\n        a new name is generated. (Default: None)\n      coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the\n        log prob loss associated with this distribution. The Fisher will be\n        multiplied by the corresponding factor. This is NOT equivalent to\n        changing the temperature of the distribution since we don't renormalize\n        the log prob in the objective function. (Default: 1.0)\n      reuse: (OPTIONAL) bool or str.  If True, this adds 'logits' as an\n        additional mini-batch/tower of inputs to the loss-function/predictive\n        distribution (which must have already been registered). 
If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n    \"\"\"\n    loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,\n                                                   seed=seed)\n    self._register_loss_function(loss, logits,\n                                 \"categorical_predictive_distribution\",\n                                 name=name, coeff=coeff, reuse=reuse)\n\n  def register_softmax_cross_entropy_loss(self,\n                                          logits,\n                                          seed=None,\n                                          targets=None,\n                                          name=None,\n                                          coeff=1.0,\n                                          reuse=VARIABLE_SCOPE):\n    \"\"\"Registers a softmax cross-entropy loss function.\n\n    Corresponds to losses computed using\n    tf.nn.sparse_softmax_cross_entropy_with_logits.\n\n    Note that this is distinct from register_sigmoid_cross_entropy_loss and\n    should not be confused with it. It is similar to\n    register_categorical_predictive_distribution but without the explicit\n    probabilistic interpretation. It behaves identically for now.\n\n    Args:\n      logits: The logits of the distribution (i.e. its parameters). The first\n        dimension must be the batch size.\n      seed: The seed for the RNG (for debugging) (Default: None)\n      targets: (OPTIONAL) The targets for the loss function.  Only required if\n        one wants to use the \"empirical Fisher\" instead of the true Fisher\n        (which is controlled by the 'estimation_mode' to the optimizer).\n        (Default: None)\n      name: (OPTIONAL) str or None. Unique name for this loss function. If None,\n        a new name is generated. (Default: None)\n      coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the\n        loss function by. 
(Default: 1.0)\n      reuse: (OPTIONAL) bool or str.  If True, this adds 'logits' as an\n        additional mini-batch/tower of inputs to the loss-function/predictive\n        distribution (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n    \"\"\"\n    loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,\n                                                   seed=seed)\n    self._register_loss_function(loss, logits,\n                                 \"sparse_softmax_cross_entropy_loss\",\n                                 name=name, coeff=coeff, reuse=reuse)\n\n  def register_normal_predictive_distribution(self,\n                                              mean,\n                                              var=0.5,\n                                              seed=None,\n                                              targets=None,\n                                              name=None,\n                                              coeff=1.0,\n                                              reuse=VARIABLE_SCOPE):\n    \"\"\"Registers a normal predictive distribution.\n\n    This corresponds to a squared error loss of the form\n       coeff/(2*var) * ||target - mean||^2\n\n    Args:\n      mean: A tensor defining the mean vector of the distribution. The first\n        dimension must be the batch size.\n      var: float. The variance of the distribution. Note that the default value\n        of 0.5 corresponds to a standard squared error loss coeff*||target -\n        prediction||^2. If you want your squared error loss to be of the form\n        0.5*coeff*||target - prediction||^2 you should use var=1.0.\n        (Default: 0.5)\n      seed: The seed for the RNG (for debugging) (Default: None)\n      targets: (OPTIONAL) The targets for the loss function.  
Only required if\n        one wants to use the \"empirical Fisher\" instead of the true Fisher\n        (which is controlled by the 'estimation_mode' to the optimizer).\n        (Default: None)\n      name: (OPTIONAL) str or None. Unique name for this loss function. If None,\n        a new name is generated. (Default: None)\n      coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the\n        log prob loss associated with this distribution. The Fisher will be\n        multiplied by the corresponding factor. In general this is NOT\n        equivalent to changing the temperature of the distribution, but in the\n        case of normal distributions it may be. (Default: 1.0)\n      reuse: (OPTIONAL) bool or str.  If True, this adds 'mean' and 'var' as an\n        additional mini-batch/tower of inputs to the loss-function/predictive\n        distribution (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n    \"\"\"\n    loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets,\n                                            seed=seed)\n    self._register_loss_function(loss, mean,\n                                 \"normal_predictive_distribution\",\n                                 name=name, coeff=coeff, reuse=reuse)\n\n  def register_squared_error_loss(self,\n                                  prediction,\n                                  seed=None,\n                                  targets=None,\n                                  name=None,\n                                  coeff=1.0,\n                                  reuse=VARIABLE_SCOPE):\n    \"\"\"Registers a squared error loss function.\n\n    This assumes the squared error loss of the form ||target - prediction||^2,\n    averaged across the mini-batch. 
If your loss uses a coefficient of 0.5\n    (tf.nn.l2_loss does this, for example) you need to set the \"coeff\" argument\n    to reflect this.\n\n    Args:\n      prediction: The prediction made by the network (i.e. its output). The\n        first dimension must be the batch size.\n      seed: The seed for the RNG (for debugging) (Default: None)\n      targets: (OPTIONAL) The targets for the loss function.  Only required if\n        one wants to use the \"empirical Fisher\" instead of the true Fisher\n        (which is controlled by the 'estimation_mode' to the optimizer).\n        (Default: None)\n      name: (OPTIONAL) str or None. Unique name for this loss function. If None,\n        a new name is generated. (Default: None)\n      coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the\n        loss function by. (Default: 1.0)\n      reuse: (OPTIONAL) bool or str.  If True, this adds 'prediction' as an\n        additional mini-batch/tower of inputs to the loss-function/predictive\n        distribution (which must have already been registered). 
If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n    \"\"\"\n    loss = lf.NormalMeanNegativeLogProbLoss(prediction, var=0.5,\n                                            targets=targets,\n                                            seed=seed)\n    self._register_loss_function(loss, prediction,\n                                 \"squared_error_loss\",\n                                 name=name, coeff=coeff, reuse=reuse)\n\n  def register_multi_bernoulli_predictive_distribution(self,\n                                                       logits,\n                                                       seed=None,\n                                                       targets=None,\n                                                       name=None,\n                                                       coeff=1.0,\n                                                       reuse=VARIABLE_SCOPE):\n    \"\"\"Registers a multi-Bernoulli predictive distribution.\n\n    Corresponds to losses computed using\n    tf.nn.sigmoid_cross_entropy_with_logits.\n\n    Note that this is distinct from\n    register_categorical_predictive_distribution and should not be confused\n    with it.\n\n\n    Args:\n      logits: The logits of the distribution (i.e. its parameters). The first\n        dimension must be the batch size.\n      seed: The seed for the RNG (for debugging) (Default: None)\n      targets: (OPTIONAL) The targets for the loss function.  Only required if\n        one wants to use the \"empirical Fisher\" instead of the true Fisher\n        (which is controlled by the 'estimation_mode' to the optimizer).\n        (Default: None)\n      name: (OPTIONAL) str or None. Unique name for this loss function. If None,\n        a new name is generated. (Default: None)\n      coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the\n        log prob loss associated with this distribution. 
The Fisher will be\n        multiplied by the corresponding factor. This is NOT equivalent to\n        changing the temperature of the distribution since we don't renormalize\n        the log prob in the objective function. (Default: 1.0)\n      reuse: (OPTIONAL) bool or str.  If True, this adds 'logits' as an\n        additional mini-batch/tower of inputs to the loss-function/predictive\n        distribution (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n    \"\"\"\n    loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,\n                                                seed=seed)\n    self._register_loss_function(loss, logits,\n                                 \"multi_bernoulli_predictive_distribution\",\n                                 name=name, coeff=coeff, reuse=reuse)\n\n  def register_sigmoid_cross_entropy_loss(self,\n                                          logits,\n                                          seed=None,\n                                          targets=None,\n                                          name=None,\n                                          coeff=1.0,\n                                          reuse=VARIABLE_SCOPE):\n    \"\"\"Registers a sigmoid cross-entropy loss function.\n\n    Corresponds to losses computed using\n    tf.nn.sigmoid_cross_entropy_with_logits.\n\n    Note that this is distinct from register_softmax_cross_entropy_loss and\n    should not be confused with it. It is similar to\n    register_multi_bernoulli_predictive_distribution but without the explicit\n    probabilistic interpretation. It behaves identically for now.\n\n    Args:\n      logits: The logits tensor. The first dimension must be the batch size.\n      seed: The seed for the RNG (for debugging) (Default: None)\n      targets: (OPTIONAL) The targets for the loss function.  
Only required if\n        one wants to use the \"empirical Fisher\" instead of the true Fisher\n        (which is controlled by the 'estimation_mode' to the optimizer).\n        (Default: None)\n      name: (OPTIONAL) str or None. Unique name for this loss function. If None,\n        a new name is generated. (Default: None)\n      coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the\n        loss function by. (Default: 1.0)\n      reuse: (OPTIONAL) bool or str.  If True, this adds 'logits' as an\n        additional mini-batch/tower of inputs to the loss-function/predictive\n        distribution (which must have already been registered). If\n        \"VARIABLE_SCOPE\", use tf.get_variable_scope().reuse.\n        (Default: \"VARIABLE_SCOPE\")\n    \"\"\"\n    loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,\n                                                seed=seed)\n    self._register_loss_function(loss, logits,\n                                 \"sigmoid_cross_entropy_loss\",\n                                 name=name, coeff=coeff, reuse=reuse)\n\n  def make_or_get_factor(self, cls, args):\n    \"\"\"Insert 'cls(args)' into 'self.fisher_factors' if not already present.\n\n    Wraps constructor in 'tf.variable_scope()' to ensure variables constructed\n    in 'cls.__init__' are placed under this LayerCollection's scope.\n\n    Args:\n      cls: Class that implements FisherFactor.\n      args: Tuple of arguments to pass into 'cls's constructor. Must be\n        hashable.\n\n    Returns:\n      Instance of 'cls' found in self.fisher_factors.\n    \"\"\"\n    # TODO(b/123190346): Should probably change the args list to be keyworded\n    # instead of positional.  
Note that this would require making changes in\n    # each FisherBlock's call to make_or_get_factor.\n    try:\n      hash(args)\n    except TypeError:\n      raise TypeError(\n          (\"Unable to use (cls, args) = ({}, {}) as a key in \"\n           \"LayerCollection.fisher_factors. The pair cannot be hashed.\").format(\n               cls, args))\n\n    key = cls, args\n    if key not in self.fisher_factors:\n      with tf.variable_scope(self._var_scope):\n        self.fisher_factors[key] = cls(*args)\n    return self.fisher_factors[key]\n\n  @contextmanager\n  def as_default(self):\n    \"\"\"Sets this LayerCollection as the default.\"\"\"\n    set_default_layer_collection(self)\n    yield\n    set_default_layer_collection(None)\n"
  },
  {
    "path": "kfac/python/ops/linear_operator.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Extra functionality we need for LinearOperators.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import utils\n\nlinalg = tf.linalg\n\n\nclass LinearOperatorExtras(object):  # pylint: disable=missing-docstring\n\n  def matmul(self, x, adjoint=False, adjoint_arg=False, name=\"matmul\"):  # pylint: disable=missing-docstring\n\n    with self._name_scope(name):\n      if isinstance(x, tf.IndexedSlices):\n        return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)\n\n      x = tf.convert_to_tensor(x, name=\"x\")\n      self._check_input_dtype(x)\n\n      self_dim = -2 if adjoint else -1\n      arg_dim = -1 if adjoint_arg else -2\n      tf.TensorShape(self.shape[self_dim]).assert_is_compatible_with(\n          x.get_shape()[arg_dim])\n\n      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)\n\n  def matmul_right(self, x, adjoint=False, adjoint_arg=False, name=\"matmul\"):  # pylint: disable=missing-docstring\n\n    with self._name_scope(name):\n\n      if isinstance(x, tf.IndexedSlices):\n        return self._matmul_right_sparse(\n            x, adjoint=adjoint, 
adjoint_arg=adjoint_arg)\n\n      x = tf.convert_to_tensor(x, name=\"x\")\n      self._check_input_dtype(x)\n\n      self_dim = -1 if adjoint else -2\n      arg_dim = -2 if adjoint_arg else -1\n      tf.TensorShape(self.shape[self_dim]).assert_is_compatible_with(\n          x.get_shape()[arg_dim])\n\n      return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg)\n\n\nclass LinearOperatorFullMatrix(LinearOperatorExtras,  # pylint: disable=missing-docstring\n                               linalg.LinearOperatorFullMatrix):\n\n  def _matmul_right(self, x, adjoint=False, adjoint_arg=False):\n    return linalg.matmul(\n        x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)\n\n  def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):\n    raise NotImplementedError\n\n  def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):\n    assert not adjoint and not adjoint_arg\n    return utils.matmul_sparse_dense(x, self._matrix)\n\n\nclass LinearOperatorDiag(LinearOperatorExtras,  # pylint: disable=missing-docstring\n                         linalg.LinearOperatorDiag):\n\n  def _matmul_right(self, x, adjoint=False, adjoint_arg=False):\n    diag_mat = tf.conj(self._diag) if adjoint else self._diag\n    x = linalg.adjoint(x) if adjoint_arg else x\n    return diag_mat * x\n\n  def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):\n    diag_mat = tf.conj(self._diag) if adjoint else self._diag\n    assert not adjoint_arg\n    return utils.matmul_diag_sparse(diag_mat, x)\n\n  def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):\n    raise NotImplementedError\n"
  },
  {
    "path": "kfac/python/ops/loss_functions.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Loss functions to be used by LayerCollection.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\n\n# Dependency imports\nimport six\nimport tensorflow.compat.v1 as tf\nimport tensorflow_probability as tfp\n\n\n@six.add_metaclass(abc.ABCMeta)\nclass LossFunction(object):\n  \"\"\"Abstract base class for loss functions.\n\n  Note that unlike typical loss functions used in neural networks these are\n  summed and not averaged across cases in the batch, since this is what the\n  users of this class (FisherEstimator and MatrixVectorProductComputer) will\n  be expecting. The implication of this is that you will may want to\n  normalize things like Fisher-vector products by the batch size when you\n  use this class.  
It depends on the use case.\n  \"\"\"\n\n  @abc.abstractproperty\n  def targets(self):\n    \"\"\"The targets being predicted by the model.\n\n    Returns:\n      None or Tensor of appropriate shape for calling self._evaluate() on.\n    \"\"\"\n    pass\n\n  @abc.abstractproperty\n  def inputs(self):\n    \"\"\"The inputs to the loss function (excluding the targets).\"\"\"\n    pass\n\n  def evaluate(self):\n    \"\"\"Evaluate the loss function on the targets.\"\"\"\n    if self.targets is not None:\n      # We treat the targets as \"constant\".  It's only the inputs that get\n      # \"back-propped\" through.\n      return self._evaluate(tf.stop_gradient(self.targets))\n    else:\n      raise Exception(\"Cannot evaluate losses with unspecified targets.\")\n\n  @abc.abstractmethod\n  def _evaluate(self, targets):\n    \"\"\"Evaluates the negative log probability of the targets.\n\n    Args:\n      targets: Tensor that distribution can calculate log_prob() of.\n\n    Returns:\n      negative log probability of each target, summed across all targets.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def multiply_ggn(self, vector):\n    \"\"\"Right-multiply a vector by the GGN.\n\n    Here the 'GGN' is the GGN matrix (whose definition is slightly flexible)\n    of the loss function with respect to its inputs.\n\n    Args:\n      vector: The vector to multiply.  Must be the same shape(s) as the\n        'inputs' property.\n\n    Returns:\n      The vector right-multiplied by the GGN.  Will be of the same shape(s)\n      as the 'inputs' property.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def multiply_ggn_factor(self, vector):\n    \"\"\"Right-multiply a vector by a factor B of the GGN.\n\n    Here the 'GGN' is the GGN matrix (whose definition is slightly flexible)\n    of the loss function with respect to its inputs.  
Typically this will be\n    block-diagonal across different cases in the batch, since the loss function\n    is typically summed across cases.\n\n    Note that B can be any matrix satisfying B * B^T = G where G is the GGN,\n    but will agree with the one used in the other methods of this class.\n\n    Args:\n      vector: The vector to multiply.  Must be of the shape given by the\n        'ggn_factor_inner_shape' property.\n\n    Returns:\n      The vector right-multiplied by B.  Will be of the same shape(s) as the\n      'inputs' property.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def multiply_ggn_factor_transpose(self, vector):\n    \"\"\"Right-multiply a vector by the transpose of a factor B of the GGN.\n\n    Here the 'GGN' is the GGN matrix (whose definition is slightly flexible)\n    of the loss function with respect to its inputs.  Typically this will be\n    block-diagonal across different cases in the batch, since the loss function\n    is typically summed across cases.\n\n    Note that B can be any matrix satisfying B * B^T = G where G is the GGN,\n    but will agree with the one used in the other methods of this class.\n\n    Args:\n      vector: The vector to multiply.  Must be the same shape(s) as the\n        'inputs' property.\n\n    Returns:\n      The vector right-multiplied by B^T.  Will be of the shape given by the\n      'ggn_factor_inner_shape' property.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def multiply_ggn_factor_replicated_one_hot(self, index):\n    \"\"\"Right-multiply a replicated-one-hot vector by a factor B of the GGN.\n\n    Here the 'GGN' is the GGN matrix (whose definition is slightly flexible)\n    of the loss function with respect to its inputs.  
Typically this will be\n    block-diagonal across different cases in the batch, since the loss function\n    is typically summed across cases.\n\n    A 'replicated-one-hot' vector means a tensor which, for each slice along the\n    batch dimension (assumed to be dimension 0), is 1.0 in the entry\n    corresponding to the given index and 0 elsewhere.\n\n    Note that B can be any matrix satisfying B * B^T = G where G is the GGN,\n    but will agree with the one used in the other methods of this class.\n\n    Args:\n      index: A tuple representing in the index of the entry in each slice that\n        is 1.0. Note that len(index) must be equal to the number of elements\n        of the 'ggn_factor_inner_shape' tensor minus one.\n\n    Returns:\n      The vector right-multiplied by B^T. Will be of the same shape(s) as the\n      'inputs' property.\n    \"\"\"\n    pass\n\n  @abc.abstractproperty\n  def ggn_factor_inner_shape(self):\n    \"\"\"The shape of the tensor returned by multiply_ggn_factor.\"\"\"\n    pass\n\n  @abc.abstractproperty\n  def ggn_factor_inner_static_shape(self):\n    \"\"\"Static version of ggn_factor_inner_shape.\"\"\"\n    pass\n\n  @property\n  def dtype(self):\n    if isinstance(self.inputs, (list, tuple)):\n      return self.inputs[0].dtype\n    return self.inputs.dtype\n\n\n@six.add_metaclass(abc.ABCMeta)\nclass NegativeLogProbLoss(LossFunction):\n  \"\"\"Abstract base class for loss functions that are negative log probs.\"\"\"\n\n  def __init__(self, seed=None):\n    self._default_seed = seed\n    super(NegativeLogProbLoss, self).__init__()\n\n  @property\n  def inputs(self):\n    return self.params\n\n  @abc.abstractproperty\n  def params(self):\n    \"\"\"Parameters to the underlying distribution.\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def multiply_fisher(self, vector):\n    \"\"\"Right-multiply a vector by the Fisher.\n\n    Args:\n      vector: The vector to multiply.  
Must be the same shape(s) as the\n        'inputs' property.\n\n    Returns:\n      The vector right-multiplied by the Fisher.  Will be of the same shape(s)\n      as the 'inputs' property.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def multiply_fisher_factor(self, vector):\n    \"\"\"Right-multiply a vector by a factor B of the Fisher.\n\n    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-\n    product of gradients) with respect to the parameters of the underlying\n    probability distribution (whose log-prob defines the loss). Typically this\n    will be block-diagonal across different cases in the batch, since the\n    distribution is usually (but not always) conditionally iid across different\n    cases.\n\n    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,\n    but will agree with the one used in the other methods of this class.\n\n    Args:\n      vector: The vector to multiply.  Must be of the shape given by the\n        'fisher_factor_inner_shape' property.\n\n    Returns:\n      The vector right-multiplied by B. Will be of the same shape(s) as the\n      'inputs' property.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def multiply_fisher_factor_transpose(self, vector):\n    \"\"\"Right-multiply a vector by the transpose of a factor B of the Fisher.\n\n    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-\n    product of gradients) with respect to the parameters of the underlying\n    probability distribution (whose log-prob defines the loss). Typically this\n    will be block-diagonal across different cases in the batch, since the\n    distribution is usually (but not always) conditionally iid across different\n    cases.\n\n    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,\n    but will agree with the one used in the other methods of this class.\n\n    Args:\n      vector: The vector to multiply.  
Must be the same shape(s) as the\n        'inputs' property.\n\n    Returns:\n      The vector right-multiplied by B^T.  Will be of the shape given by the\n      'fisher_factor_inner_shape' property.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def multiply_fisher_factor_replicated_one_hot(self, index):\n    \"\"\"Right-multiply a replicated-one-hot vector by a factor B of the Fisher.\n\n    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-\n    product of gradients) with respect to the parameters of the underlying\n    probability distribution (whose log-prob defines the loss). Typically this\n    will be block-diagonal across different cases in the batch, since the\n    distribution is usually (but not always) conditionally iid across different\n    cases.\n\n    A 'replicated-one-hot' vector means a tensor which, for each slice along the\n    batch dimension (assumed to be dimension 0), is 1.0 in the entry\n    corresponding to the given index and 0 elsewhere.\n\n    Note that B can be any matrix satisfying B * B^T = H where H is the Fisher,\n    but will agree with the one used in the other methods of this class.\n\n    Args:\n      index: A tuple representing in the index of the entry in each slice that\n        is 1.0. Note that len(index) must be equal to the number of elements\n        of the 'fisher_factor_inner_shape' tensor minus one.\n\n    Returns:\n      The vector right-multiplied by B. 
Will be of the same shape(s) as the\n      'inputs' property.\n    \"\"\"\n    pass\n\n  @abc.abstractproperty\n  def fisher_factor_inner_shape(self):\n    \"\"\"The shape of the tensor returned by multiply_fisher_factor.\"\"\"\n    pass\n\n  @abc.abstractproperty\n  def fisher_factor_inner_static_shape(self):\n    \"\"\"Static version of fisher_factor_inner_shape.\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def sample(self, seed):\n    \"\"\"Sample 'targets' from the underlying distribution.\"\"\"\n    pass\n\n  def evaluate_on_sample(self, seed=None):\n    \"\"\"Evaluates the log probability on a random sample.\n\n    Args:\n      seed: int or None. Random seed for this draw from the distribution.\n\n    Returns:\n      Log probability of sampled targets, summed across examples.\n    \"\"\"\n    if seed is None:\n      seed = self._default_seed\n    # We treat the targets as \"constant\".  It's only the inputs that get\n    # \"back-propped\" through.\n    return self._evaluate(tf.stop_gradient(self.sample(seed)))\n\n\nclass NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):\n  \"\"\"Base class for neg log prob losses whose inputs are 'natural' parameters.\n\n  We will take the GGN of the loss to be the Fisher associated with the\n  distribution, which also happens to be equal to the Hessian for this class\n  of loss functions.  See here: https://arxiv.org/abs/1412.1193\n\n  'Natural parameters' are defined for exponential-family models. 
See for\n  example: https://en.wikipedia.org/wiki/Exponential_family\n  \"\"\"\n\n  def multiply_ggn(self, vector):\n    return self.multiply_fisher(vector)\n\n  def multiply_ggn_factor(self, vector):\n    return self.multiply_fisher_factor(vector)\n\n  def multiply_ggn_factor_transpose(self, vector):\n    return self.multiply_fisher_factor_transpose(vector)\n\n  def multiply_ggn_factor_replicated_one_hot(self, index):\n    return self.multiply_fisher_factor_replicated_one_hot(index)\n\n  @property\n  def ggn_factor_inner_shape(self):\n    return self.fisher_factor_inner_shape\n\n  @property\n  def ggn_factor_inner_static_shape(self):\n    return self.fisher_factor_inner_shape\n\n\nclass DistributionNegativeLogProbLoss(NegativeLogProbLoss):\n  \"\"\"Base class for neg log prob losses that use the TF Distribution classes.\"\"\"\n\n  def __init__(self, seed=None):\n    super(DistributionNegativeLogProbLoss, self).__init__(seed=seed)\n\n  @abc.abstractproperty\n  def dist(self):\n    \"\"\"The underlying tfp.distributions.Distribution.\"\"\"\n    pass\n\n  def _evaluate(self, targets):\n    return -tf.reduce_sum(self.dist.log_prob(targets))\n\n  def sample(self, seed):\n    return self.dist.sample(seed=seed)\n\n\nclass NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,\n                                    NaturalParamsNegativeLogProbLoss):\n  \"\"\"Neg log prob loss for a normal distribution parameterized by a mean vector.\n\n\n  Note that the covariance is treated as a constant 'var' times the identity.\n  Also note that the Fisher for such a normal distribution with respect the mean\n  parameter is given by:\n\n     F = (1/var) * I\n\n  See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.\n  \"\"\"\n\n  def __init__(self, mean, var=0.5, targets=None, seed=None):\n    assert isinstance(var, float)  # variance must be a constant float\n\n    self._mean = mean\n    self._var = var\n    self._targets = targets\n    
super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed)\n\n  @property\n  def targets(self):\n    return self._targets\n\n  @property\n  def dist(self):\n    return tfp.distributions.Normal(loc=self._mean, scale=tf.sqrt(self._var))\n\n  @property\n  def params(self):\n    return self._mean\n\n  def multiply_fisher(self, vector):\n    return (1. / self._var) * vector\n\n  def multiply_fisher_factor(self, vector):\n    return self._var**-0.5 * vector\n\n  def multiply_fisher_factor_transpose(self, vector):\n    return self.multiply_fisher_factor(vector)  # it's symmetric in this case\n\n  def multiply_fisher_factor_replicated_one_hot(self, index):\n    assert len(index) == 1, \"Length of index was {}\".format(len(index))\n    ones_slice = tf.expand_dims(\n        tf.ones(tf.shape(self._mean)[:1], dtype=self._mean.dtype), axis=-1)\n    output_slice = self._var**-0.5 * ones_slice\n    return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),\n                                 index[0])\n\n  @property\n  def fisher_factor_inner_shape(self):\n    return tf.shape(self._mean)\n\n  @property\n  def fisher_factor_inner_static_shape(self):\n    return self._mean.shape\n\n\nclass NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):\n  \"\"\"Negative log prob loss for a normal distribution with mean and variance.\n\n  This class parameterizes a multivariate normal distribution with n independent\n  dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not\n  assume the variance is held constant. The Fisher Information for n = 1\n  is given by,\n\n  F = [[1 / variance,                0],\n       [           0, 0.5 / variance^2]]\n\n  where the parameters of the distribution are concatenated into a single\n  vector as [mean, variance]. 
For n > 1, the mean parameter vector is\n  concatenated with the variance parameter vector.\n\n  See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation.\n  \"\"\"\n\n  def __init__(self, mean, variance, targets=None, seed=None):\n    assert len(mean.shape) == 2, \"Expect 2D mean tensor.\"\n    assert len(variance.shape) == 2, \"Expect 2D variance tensor.\"\n    self._mean = mean\n    self._variance = variance\n    self._targets = targets\n    super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)\n\n  @property\n  def targets(self):\n    return self._targets\n\n  @property\n  def dist(self):\n    return tfp.distributions.Normal(\n        loc=self._mean, scale=tf.sqrt(self._variance))\n\n  @property\n  def params(self):\n    return self._mean, self._variance\n\n  def _concat(self, mean, variance):\n    return tf.concat([mean, variance], axis=-1)\n\n  def _split(self, params):\n    return tf.split(params, 2, axis=-1)\n\n  @property\n  def _fisher_mean(self):\n    return 1. / self._variance\n\n  @property\n  def _fisher_mean_factor(self):\n    return 1. / tf.sqrt(self._variance)\n\n  @property\n  def _fisher_var(self):\n    return 1. / (2 * tf.square(self._variance))\n\n  @property\n  def _fisher_var_factor(self):\n    return 1. / (tf.sqrt(2.) 
* self._variance)\n\n  def multiply_fisher(self, vecs):\n    mean_vec, var_vec = vecs\n    return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)\n\n  def multiply_fisher_factor(self, vecs):\n    mean_vec, var_vec = self._split(vecs)\n    return (self._fisher_mean_factor * mean_vec,\n            self._fisher_var_factor * var_vec)\n\n  def multiply_fisher_factor_transpose(self, vecs):\n    mean_vec, var_vec = vecs\n    return self._concat(self._fisher_mean_factor * mean_vec,\n                        self._fisher_var_factor * var_vec)\n\n  def multiply_fisher_factor_replicated_one_hot(self, index):\n    assert len(index) == 1, \"Length of index was {}\".format(len(index))\n    index = index[0]\n\n    if index < int(self._mean.shape[-1]):\n      # Index corresponds to mean parameter.\n      mean_slice = self._fisher_mean_factor[:, index]\n      mean_slice = tf.expand_dims(mean_slice, axis=-1)\n      mean_output = insert_slice_in_zeros(mean_slice, 1, int(\n          self._mean.shape[1]), index)\n      var_output = tf.zeros_like(mean_output)\n    else:\n      index -= int(self._mean.shape[-1])\n      # Index corresponds to variance parameter.\n      var_slice = self._fisher_var_factor[:, index]\n      var_slice = tf.expand_dims(var_slice, axis=-1)\n      var_output = insert_slice_in_zeros(var_slice, 1,\n                                         int(self._variance.shape[1]), index)\n      mean_output = tf.zeros_like(var_output)\n\n    return mean_output, var_output\n\n  @property\n  def fisher_factor_inner_shape(self):\n    return tf.concat(\n        [tf.shape(self._mean)[:-1], 2 * tf.shape(self._mean)[-1:]], axis=0)\n\n  @property\n  def fisher_factor_inner_static_shape(self):\n    shape = self._mean.shape.as_list()\n    return tf.TensorShape(shape[-1:] + [2 * shape[-1]])\n\n  def multiply_ggn(self, vector):\n    raise NotImplementedError()\n\n  def multiply_ggn_factor(self, vector):\n    raise NotImplementedError()\n\n  def multiply_ggn_factor_transpose(self, 
vector):\n    raise NotImplementedError()\n\n  def multiply_ggn_factor_replicated_one_hot(self, index):\n    raise NotImplementedError()\n\n  @property\n  def ggn_factor_inner_shape(self):\n    raise NotImplementedError()\n\n  @property\n  def ggn_factor_inner_static_shape(self):\n    raise NotImplementedError()\n\n\nclass CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,\n                                           NaturalParamsNegativeLogProbLoss):\n  \"\"\"Neg log prob loss for a categorical distribution parameterized by logits.\n\n\n  Note that the Fisher (for a single case) of a categorical distribution, with\n  respect to the natural parameters (i.e. the logits), is given by:\n\n  F = diag(p) - p*p^T\n\n  where p = softmax(logits).  F can be factorized as F = B * B^T where\n\n  B = diag(q) - p*q^T\n\n  where q is the entry-wise square root of p. This is easy to verify using the\n  fact that q^T*q = 1.\n  \"\"\"\n\n  def __init__(self, logits, targets=None, seed=None):\n    \"\"\"Instantiates a CategoricalLogitsNegativeLogProbLoss.\n\n    Args:\n      logits: Tensor of shape [batch_size, output_size]. Parameters for\n        underlying distribution.\n      targets: None or Tensor of shape [batch_size]. Each elements contains an\n        index in [0, output_size).\n      seed: int or None. 
Default random seed when sampling.\n    \"\"\"\n    self._logits = logits\n    self._targets = targets\n    super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)\n\n  @property\n  def targets(self):\n    return self._targets\n\n  @property\n  def dist(self):\n    return tfp.distributions.Categorical(logits=self._logits)\n\n  @property\n  def _probs(self):\n    return self.dist.probs_parameter()\n\n  @property\n  def _sqrt_probs(self):\n    return tf.sqrt(self._probs)\n\n  @property\n  def params(self):\n    return self._logits\n\n  def multiply_fisher(self, vector):\n    probs = self._probs\n    return vector * probs - probs * tf.reduce_sum(\n        vector * probs, axis=-1, keepdims=True)\n\n  def multiply_fisher_factor(self, vector):\n    probs = self._probs\n    sqrt_probs = self._sqrt_probs\n    return sqrt_probs * vector - probs * tf.reduce_sum(\n        sqrt_probs * vector, axis=-1, keepdims=True)\n\n  def multiply_fisher_factor_transpose(self, vector):\n    probs = self._probs\n    sqrt_probs = self._sqrt_probs\n    return sqrt_probs * vector - sqrt_probs * tf.reduce_sum(\n        probs * vector, axis=-1, keepdims=True)\n\n  def multiply_fisher_factor_replicated_one_hot(self, index):\n    assert len(index) == 1, \"Length of index was {}\".format(len(index))\n    probs = self._probs\n    sqrt_probs = self._sqrt_probs\n    sqrt_probs_slice = tf.expand_dims(sqrt_probs[:, index[0]], -1)\n    padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1,\n                                         int(sqrt_probs.shape[1]), index[0])\n    return padded_slice - probs * sqrt_probs_slice\n\n  @property\n  def fisher_factor_inner_shape(self):\n    return tf.shape(self._logits)\n\n  @property\n  def fisher_factor_inner_static_shape(self):\n    return self._logits.shape\n\n\nclass MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,\n                                        NaturalParamsNegativeLogProbLoss):\n  \"\"\"Neg log prob loss for multiple 
Bernoulli distributions param'd by logits.\n\n  Represents N independent Bernoulli distributions where N = len(logits). Its\n  Fisher Information matrix is given by,\n\n  F = diag(p * (1-p))\n  p = sigmoid(logits)\n\n  As F is diagonal with positive entries, its factor B is,\n\n  B = diag(sqrt(p * (1-p)))\n  \"\"\"\n\n  def __init__(self, logits, targets=None, seed=None):\n    self._logits = logits\n    self._targets = targets\n    super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed)\n\n  @property\n  def targets(self):\n    return self._targets\n\n  @property\n  def dist(self):\n    return tfp.distributions.Bernoulli(logits=self._logits)\n\n  @property\n  def _probs(self):\n    return self.dist.probs_parameter()\n\n  @property\n  def params(self):\n    return self._logits\n\n  def multiply_fisher(self, vector):\n    return self._probs * (1 - self._probs) * vector\n\n  def multiply_fisher_factor(self, vector):\n    return tf.sqrt(self._probs * (1 - self._probs)) * vector\n\n  def multiply_fisher_factor_transpose(self, vector):\n    return self.multiply_fisher_factor(vector)  # it's symmetric in this case\n\n  def multiply_fisher_factor_replicated_one_hot(self, index):\n    assert len(index) == 1, \"Length of index was {}\".format(len(index))\n    probs_slice = tf.expand_dims(self._probs[:, index[0]], -1)\n    output_slice = tf.sqrt(probs_slice * (1 - probs_slice))\n    return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),\n                                 index[0])\n\n  @property\n  def fisher_factor_inner_shape(self):\n    return tf.shape(self._logits)\n\n  @property\n  def fisher_factor_inner_static_shape(self):\n    return self._logits.shape\n\n\ndef insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):\n  \"\"\"Inserts slice into a larger tensor of zeros.\n\n  Forms a new tensor which is the same shape as slice_to_insert, except that\n  the dimension given by 'dim' is expanded to the size given by 'dim_size'.\n  
'position' determines the position (index) at which to insert the slice within\n  that dimension.\n\n  Assumes slice_to_insert.shape[dim] = 1.\n\n  Args:\n    slice_to_insert: The slice to insert.\n    dim: The dimension which to expand with zeros.\n    dim_size: The new size of the 'dim' dimension.\n    position: The position of 'slice_to_insert' in the new tensor.\n\n  Returns:\n    The new tensor.\n\n  Raises:\n    ValueError: If the slice's shape at the given dim is not 1.\n  \"\"\"\n  slice_shape = slice_to_insert.shape\n  if slice_shape[dim] != 1:\n    raise ValueError(\"Expected slice_to_insert.shape to have {} dim of 1, but \"\n                     \"was {}\".format(dim, slice_to_insert.shape[dim]))\n\n  before = [0] * int(len(slice_shape))\n  after = before[:]\n  before[dim] = position\n  after[dim] = dim_size - position - 1\n\n  return tf.pad(slice_to_insert, list(zip(before, after)))\n\n\nclass OnehotCategoricalLogitsNegativeLogProbLoss(\n    CategoricalLogitsNegativeLogProbLoss):\n  \"\"\"Neg log prob loss for a categorical distribution with onehot targets.\n\n  Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying\n  distribution is OneHotCategorical as opposed to Categorical.\n  \"\"\"\n\n  @property\n  def dist(self):\n    return tfp.distributions.OneHotCategorical(logits=self._logits)\n"
  },
  {
    "path": "kfac/python/ops/op_queue.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Helper for choosing which op to run next in a distributed setting.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport tensorflow.compat.v1 as tf\n\n\nclass OpQueue(object):\n  \"\"\"Class for choosing which Op to run next.\n\n  Constructs an infinitely repeating sequence of Ops in shuffled order.\n\n  In K-FAC, this can be used to distribute inverse update operations among\n  workers.\n  \"\"\"\n\n  def __init__(self, ops, seed=None):\n    \"\"\"Initializes an OpQueue.\n\n    Args:\n      ops: list of TensorFlow Ops. Ops to be selected from. All workers must\n        initialize with the same set of ops.\n      seed: int or None. 
Random seed used when shuffling order of ops.\n    \"\"\"\n    self._ops_by_name = {op.name: op for op in ops}\n\n    # Construct a (shuffled) Dataset with Op names.\n    op_names = tf.convert_to_tensor(list(sorted(op.name for op in ops)))\n    op_names_dataset = (\n        tf.data.Dataset.from_tensor_slices(op_names).shuffle(\n            len(ops), seed=seed).repeat())\n    self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next()\n\n  @property\n  def ops(self):\n    \"\"\"Ops this OpQueue can return in next_op().\"\"\"\n    return self._ops_by_name.values()\n\n  def next_op(self, sess):\n    \"\"\"Chooses which op to run next.\n\n    Note: This call will make a call to sess.run().\n\n    Args:\n      sess: tf.Session.\n\n    Returns:\n      Next Op chosen from 'ops'.\n    \"\"\"\n    # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii')\n    # returns a str.\n    next_op_name = sess.run(self._next_op_name).decode('ascii')\n    return self._ops_by_name[next_op_name]\n"
  },
  {
    "path": "kfac/python/ops/optimizer.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"The KFAC optimizer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport re\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops import curvature_matrix_vector_products as cmvp\nfrom kfac.python.ops import estimator as est\nfrom kfac.python.ops import fisher_factors as ff\nfrom kfac.python.ops import utils as utils\n\nip = utils.ip\nip_p = utils.ip_p\nsprod = utils.sprod\nsprod_p = utils.sprod_p\n\n# If True we the damping contribution is included in the quadratic model for\n# the purposes of computing qmodel_change in rho (the reduction ratio used in\n# the LM damping adjustment method). 
Note that the extra damping from the\n# \"l2_reg\" argument is always included.\n_INCLUDE_DAMPING_IN_QMODEL_CHANGE = False\n\n\ndef set_global_constants(include_damping_in_qmodel_change=None):\n  \"\"\"Sets various global constants used by the classes in this module.\"\"\"\n  global _INCLUDE_DAMPING_IN_QMODEL_CHANGE\n\n  if include_damping_in_qmodel_change is not None:\n    _INCLUDE_DAMPING_IN_QMODEL_CHANGE = include_damping_in_qmodel_change\n\n\nclass KfacOptimizer(tf.train.GradientDescentOptimizer):\n  \"\"\"The KFAC Optimizer (https://arxiv.org/abs/1503.05671).\"\"\"\n\n  def __init__(self,\n               learning_rate,\n               damping,\n               layer_collection,\n               cov_ema_decay=0.95,\n               var_list=None,\n               momentum=0.9,\n               momentum_type=\"adam\",\n               use_weight_decay=False,\n               weight_decay_coeff=0.1,\n               qmodel_update_rescale=None,\n               norm_constraint=None,\n               name=\"KFAC\",\n               estimation_mode=\"gradients\",\n               colocate_gradients_with_ops=True,\n               batch_size=None,\n               placement_strategy=None,\n               compute_params_stats=False,\n               adapt_damping=False,\n               update_damping_immediately=True,\n               is_chief=True,\n               prev_train_batch=None,\n               loss=None,\n               loss_fn=None,\n               min_damping=1e-8,  # this value is somewhat arbitrary\n               l2_reg=0.0,\n               damping_adaptation_decay=0.95,\n               damping_adaptation_interval=5,\n               damping_decrease_rho_threshold=0.75,\n               damping_increase_rho_threshold=0.25,\n               precon_damping_mult=1.0,\n               use_passed_loss=True,\n               train_batch=None,\n               print_logs=False,\n               tf_replicator=None,\n               dtype=\"float32\",\n               **kwargs):\n    
\"\"\"Initializes the K-FAC optimizer with the given settings.\n\n      NOTE: this is a base class for K-FAC optimizers that offers full control\n      over the execution of K-FAC's various ops.  For a more fool-proof /\n      automated version see for example PeriodicInvCovUpdateKfacOpt.\n\n      Also, please keep in mind that while the K-FAC code loosely conforms to\n      TensorFlow's Optimizer API it can't be used naively as a \"drop in\n      replacement\" for basic classes like MomentumOptimizer.  Using it\n      properly with SyncReplicasOptimizer, for example, requires special care.\n      When using it with Distribution Strategy, unlike common practice, K-FAC\n      expects a loss tensor that is normalized by the per-replica batch size,\n      and *not* by the total batch size (like you may see in TF Distribution\n      Strategy tutorials). Regardless of whether you are using estimator,\n      strategy, or a normal custom training loop, you should pass in the same\n      loss.\n\n      See the various examples in the \"examples\" directory for a guide about\n      how to use K-FAC in various contexts and various systems, like\n      TF-Estimator. See in particular the \"convnet\" example.  google/examples\n      also contains an example using TPUEstimator.\n\n    Args:\n      learning_rate: float or 0D Tensor. The base learning rate for the\n          optimizer. Must be set to None if using one of the 'qmodel'\n          momentum_type values.\n      damping: float or 0D Tensor. This quantity times the identity matrix is\n          (approximately) added to the curvature matrix (i.e. the Fisher or GGN)\n          before it is inverted multiplied by the gradient when computing the\n          (raw) update. This quantity should match the scale of the objective,\n          so that if you put a multiplier on your loss you should apply the\n          same multiplier to the damping. 
Roughly speaking, larger values\n          constrain the update vector to a smaller region around zero, which\n          we want to do when our local quadratic model is a less trustworthy\n          local approximation of the true objective.  The damping value is\n          closely related to the trust region radius and to the classical\n          Tikhonov regularization method. If the `adapt_damping` argument is\n          True then this value is used only as an initial value for the\n          adaptation method.\n      layer_collection: The layer collection object, which holds the Fisher\n          blocks, Kronecker factors, and losses associated with the\n          graph.  The layer_collection cannot be modified after KfacOptimizer's\n          initialization.\n      cov_ema_decay: The decay factor used when calculating the\n          covariance estimate moving averages. (Default: 0.95)\n      var_list: Optional list or tuple of variables to train. Defaults to\n          tf.trainable_variables.\n      momentum: The momentum decay constant to use. Only applies when\n          momentum_type is 'regular' or 'adam'. (Default: 0.9)\n      momentum_type: The type of momentum to use in this optimizer, one of\n          'regular', 'adam', 'qmodel', or 'qmodel_fixedmu'. 'regular' gives\n          standard momentum. 'adam' gives a style of momentum reminisent\n          of the Adam method, which seems to work better in practice.\n          'qmodel' makes the optimizer perform automatic control of both the\n          learning rate and momentum using a quadratic model based method\n          (see _compute_qmodel_hyperparams for more details). 
'qmodel_fixedmu'\n          is similar to 'qmodel' but only controls the learning rate.\n          (Default: 'adam')\n      use_weight_decay: If True, explicit \"weight decay\" is performed by K-FAC.\n          Note that this is distinct from L2 regularization, and corresponds to\n          optimizing a regularized version of the loss passed to minimize(),\n          where the regularization term added is related to the \"Fisher-Rao\n          norm\". See https://openreview.net/pdf?id=B1lz-3Rct7 for more details.\n          Note that using this feature won't change the loss function you pass\n          to minimize(), and thus the loss you report will not correspond\n          precisely to what K-FAC is optimizing. (Default: False)\n      weight_decay_coeff: The coefficient to use for weight decay (see above).\n          (Default: 0.1)\n      qmodel_update_rescale: float or None.  An additional multiplier to apply\n          to the update computed by the quadratic model based adjustment\n          methods. If None it will behave like a value of 1.0. (Default: None)\n      norm_constraint: float or Tensor. If specified, the update is scaled down\n          so that its approximate squared Fisher norm v^T F v is at most the\n          specified value. May only be used with momentum type 'regular'.  See\n          the docstring for the method _clip_updates() for a more detailed\n          explanation of this feature. (Default: None)\n      name: The name for this optimizer. (Default: 'KFAC')\n      estimation_mode: The type of estimator to use for the Fishers/GGNs. Can be\n          'gradients', 'empirical', 'curvature_prop', 'curvature_prop_GGN',\n          'exact', or 'exact_GGN'. See the doc-string for FisherEstimator\n          (in estimator.py) for more a more detailed description of these\n          options. 
(Default: 'gradients').\n      colocate_gradients_with_ops: Whether we should request gradients we\n          compute in the estimator be colocated with their respective ops.\n          (Default: True)\n      batch_size: The size of the mini-batch. Only needed when `momentum_type`\n          == 'qmodel' or when `compute_params_stats` is True. Note that when\n          using data parallelism where the model graph and optimizer are\n          replicated across multiple devices, this should be the per-replica\n          batch size. An example of this is sharded data on the TPU, where\n          batch_size should be set to the total batch size divided by the\n          number of shards. (Default: None)\n      placement_strategy: string or None. Device placement strategy used when\n          creating variables, and various ops. Can be None, 'round_robin', or\n          'replica_round_robin'. 'round_robin' supports round-robin placement of\n          various ops on lists of provided devices. 'replica_round_robin' does\n          something similar but over shards/replicas instead, and only works\n          in certain 'replicated' contexts (e.g. TPUEstimator).  The details of\n          the different placement strategies are controlled by additional\n          keyword arguments that can be passed to this class, and which are\n          described in the different placement mixin classes in placement.py.\n          (Default: None)\n      compute_params_stats: Bool. If True, we compute the first order version\n          of the statistics computed to estimate the Fisher/GGN. These\n          correspond to the `variables` method in a one-to-one fashion.  They\n          are available via the `params_stats` property.  When estimation_mode\n          is 'empirical', this will correspond to the standard parameter\n          gradient on the loss. (Default: False)\n      adapt_damping: `Boolean`. 
If True we adapt the damping according to the\n          Levenberg-Marquardt rule described in Section 6.5 of the original\n          K-FAC paper. The details of this scheme are controlled by various\n          additional arguments below. Also some of these arguments are extra\n          pieces of information, such as the loss, needed by the method. Note\n          that unless using a convenience subclass like\n          PeriodicInvCovUpdateKfacOpt the damping adaptation op must be\n          executed by the user (like the cov and inv ops). This op is returned\n          by the maybe_pre_update_adapt_damping() method. (Default: False)\n      update_damping_immediately: Damping adjustment strategy. If True then the\n          damping is updated in the same optimizer minimize call as\n          `(step+1) % damping_adaptation_interval == 0`, immediately after the\n          parameter update is performed. If False then the damping is updated\n          in the next step. If True then it is assumed that the apply_gradients\n          op will safely update the model before returning; it is recommended\n          to only resource variables in this case. (Default: True)\n      is_chief: `Boolean`, `True` if the worker is chief. (Default: True)\n      prev_train_batch: Training mini-batch used in the previous step. This\n          will be used to evaluate loss by calling `loss_fn(prev_train_batch)`\n          when damping adaptation is used. (Default: None)\n      loss: `Tensor` the model loss, used as the pre-update loss in adaptive\n          damping. Also used for the built-in log printing. When using\n          Distribution Strategy, unlike common Distribution Strategy practice,\n          this loss tensor should by normalized by the per-replica batch size\n          and NOT the total batch size. (Default: None)\n      loss_fn: `function` that takes as input training data tensor and returns\n          a scalar loss. Only needed if using damping adaptation. 
When using\n          Distribution Strategy, unlike common Distribution Strategy practice,\n          the loss should by normalized by the per-replica batch size and NOT\n          the total batch size. (Default: None)\n      min_damping: `float`, Minimum value the damping parameter can take. Note\n          that the default value of 1e-8 is quite arbitrary, and you may have\n          to adjust this up or down for your particular problem. If you are\n          using a non-zero value of l2_reg you *may* be able to set this to\n          zero. (Default: 1e-8)\n      l2_reg: `float` or 0D Tensor. Set this value to tell the optimizer what L2\n          regularization coefficient you are using (if any). Note the\n          coefficient appears in the regularizer as coeff / 2 * sum(param**2),\n          as the thing you multiply tf.nn.l2(param) by. This will be essentially\n          added to the minimum damping, but also included in the qmodel change\n          computations (used for adjusting the damping) even when\n          _INCLUDE_DAMPING_IN_QMODEL_CHANGE is False. Note that the user is\n          still responsible for adding regularization to the loss.\n          (Default: 0.0)\n      damping_adaptation_decay: `float`. The `damping` parameter is\n          multiplied by the `damping_adaptation_decay` every\n          `damping_adaptation_interval` number of iterations. (Default: 0.99)\n      damping_adaptation_interval: `int`. Number of steps in between\n          updating the `damping` parameter. Note that damping is adapted at\n          the step where (step+1) % damping_adaptation_interval == 0,\n          (or immediately at the start of the next step by\n          maybe_pre_update_adapt_damping() if update_damping_immediately is\n          False). (Default: 5)\n      damping_decrease_rho_threshold: `int`. 
The threshold for rho above which\n          we decrease the damping when using automatic damping adaptation.\n          (Default: 0.75)\n      damping_increase_rho_threshold: `int`. The threshold for rho below which\n          we increase the damping when using automatic damping adaptation.\n          (Default: 0.25)\n      precon_damping_mult: `float`. A multiplier used on the damping value\n          passed to the preconditioner (vs the quadratic model when using\n          momentum_type 'qmodel'). (Default: 1.0)\n      use_passed_loss: `Boolean`. If True we use the loss tensor passed in by\n          the user (via minimze() or compute_gradients() or the set_loss()\n          method) in damping adaptation scheme, instead of calling loss_fn()\n          a second time for this. This is more efficient but may not always be\n          desired. (Default: True)\n      train_batch: Training mini-batch used in the current step. This\n          will be used to evaluate loss by calling `loss_fn(train_batch)`\n          when damping adaptation is used and `use_passed_loss` is False.\n          (Default: None)\n      print_logs: `Boolean`. If True, we print some logging info using\n          tf.print after each iteration. This is done in the method\n          _maybe_print_logging_info, which we encourage you to modify in order\n          to add whatever you want. (Default: False)\n      tf_replicator: A Replicator object or None. If not None, K-FAC will set\n          itself up to work inside of the provided TF-Replicator object.\n          (Default: None)\n      dtype: TF dtype or string representing one. dtype used for scalar\n          properties (rho, etc). 
(Default: \"float32\")\n      **kwargs: Arguments to be passed to specific placement strategy mixin.\n          Check `placement.RoundRobinPlacementMixin` for example.\n\n    Raises:\n      ValueError: If the momentum type is unsupported.\n      ValueError: If clipping is used with momentum type other than 'regular'.\n      ValueError: If no losses have been registered with layer_collection.\n      ValueError: If momentum is non-zero and momentum_type is not 'regular'\n          or 'adam'.\n    \"\"\"\n    dtype = tf.dtypes.as_dtype(dtype)\n    self._dtype = dtype\n\n    self._layers = layer_collection\n\n    self._colocate_gradients_with_ops = colocate_gradients_with_ops\n\n    momentum_type = momentum_type.lower()\n    legal_momentum_types = [\"regular\", \"adam\", \"qmodel\", \"qmodel_fixedmu\"]\n\n    if momentum_type not in legal_momentum_types:\n      raise ValueError(\"Unsupported momentum type {}. Must be one of {}.\"\n                       .format(momentum_type, legal_momentum_types))\n    if momentum_type not in [\"regular\", \"adam\"] and norm_constraint is not None:\n      raise ValueError(\"Update clipping is only supported with momentum \"\n                       \"type 'regular' and 'adam'.\")\n    if momentum_type == \"qmodel\" and momentum is not None:\n      raise ValueError(\"Momentum must be None if using a momentum_type \"\n                       \"'qmodel'.\")\n    self._momentum_type = momentum_type\n    self._momentum = momentum\n\n    self._use_weight_decay = use_weight_decay\n    self._weight_decay_coeff = weight_decay_coeff\n\n    self._norm_constraint = norm_constraint\n    self._batch_size = batch_size\n    self._placement_strategy = placement_strategy\n\n    # Damping adaptation parameters\n    self._adapt_damping = adapt_damping\n\n    if self._adapt_damping:\n      with tf.variable_scope(name):\n        self._damping = tf.get_variable(\n            \"damping\", initializer=lambda: tf.constant(damping, dtype=dtype),\n            
trainable=False, use_resource=True, dtype=dtype)\n    else:\n      self._damping = damping\n\n    self._update_damping_immediately = update_damping_immediately\n    self._is_chief = is_chief\n    self._prev_train_batch = prev_train_batch\n    self._loss_tensor = loss\n    self._loss_fn = loss_fn\n    self._damping_adaptation_decay = damping_adaptation_decay\n    self._damping_adaptation_interval = damping_adaptation_interval\n    self._omega = (\n        self._damping_adaptation_decay**self._damping_adaptation_interval)\n    self._min_damping = min_damping\n    self._use_passed_loss = use_passed_loss\n    if not use_passed_loss and train_batch is None:\n      raise ValueError(\"Must pass in train_batch if used_passed_loss is false.\")\n\n    self._damping_decrease_rho_threshold = damping_decrease_rho_threshold\n    self._damping_increase_rho_threshold = damping_increase_rho_threshold\n\n    self._train_batch = train_batch\n\n    self._print_logs = print_logs\n\n    self._l2_reg = l2_reg\n\n    self._precon_damping_mult = precon_damping_mult\n\n    if self._momentum_type.startswith(\"qmodel\"):\n      if learning_rate is not None:\n        raise ValueError(\"'learning_rate' must be set to None if using one of \"\n                         \"the 'qmodel' momentum types.\")\n      if qmodel_update_rescale is not None:\n        learning_rate = qmodel_update_rescale\n      else:\n        learning_rate = 1.0\n    else:\n      if learning_rate is None:\n        raise ValueError(\"'learning_rate' must *not* be set to None unless \"\n                         \"using one of the 'qmodel' momentum types.\")\n      if qmodel_update_rescale is not None:\n        raise ValueError(\"'qmodel_update_rescale' must be None unless using \"\n                         \"one of the 'qmodel' momentum types.\")\n    self._qmodel_update_rescale = qmodel_update_rescale\n\n    with tf.variable_scope(name):\n      nan_init = lambda: tf.constant(float(\"nan\"), dtype=dtype)\n      # We store rho 
only for possible logging purposes.\n      self._rho = tf.get_variable(\n          \"rho\", initializer=nan_init, dtype=dtype,\n          trainable=False, use_resource=True)\n      self._prev_loss = tf.get_variable(\n          \"prev_loss\", initializer=nan_init, dtype=dtype,\n          trainable=False, use_resource=True)\n      self._qmodel_learning_rate = tf.get_variable(\n          \"qmodel_learning_rate\", initializer=nan_init, dtype=dtype,\n          trainable=False, use_resource=True)\n      self._qmodel_momentum = tf.get_variable(\n          \"qmodel_momentum\", initializer=nan_init, dtype=dtype,\n          trainable=False, use_resource=True)\n      self._qmodel_change = tf.get_variable(\n          \"qmodel_change\", initializer=nan_init, dtype=dtype,\n          trainable=False, use_resource=True)\n\n      self._counter = tf.get_variable(\n          \"counter\", dtype=tf.int64, shape=(), trainable=False,\n          initializer=tf.zeros_initializer, use_resource=True,\n          aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)\n\n      variables = var_list or tf.trainable_variables()\n\n      if tf_replicator is not None or tf.distribute.has_strategy():\n        def _get_sanitized_name(var_name):\n          return re.sub(r\"replica_\\d+_\", \"\", var_name)\n\n        # This tells K-FAC's libraries that we are using TF-Replicator with this\n        # particular Replicator object.\n        utils.set_global_constants(tf_replicator=tf_replicator)\n\n        # We need to sanitize the names of the variables that K-FAC creates\n        # so they are the same between replicas.\n        ff.set_global_constants(get_sanitized_name_fn=_get_sanitized_name)\n\n      self._fisher_est = est.make_fisher_estimator(\n          placement_strategy=placement_strategy,\n          variables=variables,\n          cov_ema_decay=cov_ema_decay,\n          damping=self._damping * self._precon_damping_mult,\n          layer_collection=self.layers,\n          exps=(-1,),\n          
estimation_mode=estimation_mode,\n          colocate_gradients_with_ops=self._colocate_gradients_with_ops,\n          compute_params_stats=compute_params_stats,\n          batch_size=batch_size,\n          **kwargs)\n\n    super(KfacOptimizer, self).__init__(learning_rate, name=name)\n\n  def get_cov_vars(self):\n    \"\"\"Returns all covaraiance varaiables.\"\"\"\n    return self._fisher_est.get_cov_vars()\n\n  def get_inv_vars(self):\n    \"\"\"Returns all inverse computation related varaiables.\"\"\"\n    return self._fisher_est.get_inv_vars()\n\n  @property\n  def factors(self):\n    return self._fisher_est.factors\n\n  @property\n  def registered_variables(self):\n    return self._fisher_est.variables\n\n  @property\n  def layers(self):\n    return self._layers\n\n  @property\n  def mat_type(self):\n    return self._fisher_est.mat_type\n\n  @property\n  def damping(self):\n    if self._adapt_damping:\n      return tf.identity(self._damping)\n    else:\n      return tf.convert_to_tensor(self._damping)\n\n  @property\n  def damping_adaptation_interval(self):\n    return self._damping_adaptation_interval\n\n  @property\n  def learning_rate(self):\n    if self._momentum_type.startswith(\"qmodel\"):\n      return self._learning_rate * tf.identity(self._qmodel_learning_rate)\n    else:\n      return tf.convert_to_tensor(self._learning_rate)\n\n  @property\n  def momentum(self):\n    if self._momentum_type.startswith(\"qmodel\"):\n      return tf.identity(self._qmodel_momentum)\n    else:\n      return tf.convert_to_tensor(self._momentum)\n\n  @property\n  def rho(self):\n    return tf.identity(self._rho)\n\n  @property\n  def qmodel_change(self):\n    return tf.identity(self._qmodel_change)\n\n  @property\n  def counter(self):\n    return tf.identity(self._counter)\n\n  @property\n  def params_stats(self):\n    return self._fisher_est.params_stats\n\n  def set_loss(self, loss):\n    # Use this method if you have overridden both the minimize method and\n    # 
compute_gradients method but still want K-FAC to know the loss value\n    # (which is required for damping adaptation).\n    self._loss_tensor = loss\n\n  def _maybe_print_logging_info(self):\n    if not self._print_logs:\n      return tf.no_op()\n\n    p = []\n    p.append((\"=========================================================\",))\n    p.append((\"Iteration:\", self.counter))\n    p.append((\"mini-batch loss =\", self._loss_tensor))\n    p.append((\"learning_rate =\", self.learning_rate, \"| momentum =\",\n              self.momentum))\n    p.append((\"damping =\", self.damping, \"| rho =\", self.rho,\n              \"| qmodel_change =\", self.qmodel_change))\n    p.append((\"=========================================================\",))\n\n    return utils.multiline_print(p)\n\n  def make_vars_and_create_op_thunks(self):\n    \"\"\"Make vars and create op thunks.\n\n    Returns:\n      cov_update_thunks: List of cov update thunks. Corresponds one-to-one with\n        the list of factors given by the \"factors\" property.\n      inv_update_thunks: List of inv update thunks. Corresponds one-to-one with\n        the list of factors given by the \"factors\" property.\n    \"\"\"\n    scope = self.get_name() + \"/\" + self._fisher_est.name\n    return self._fisher_est.make_vars_and_create_op_thunks(scope=scope)\n\n  def create_ops_and_vars_thunks(self):\n    \"\"\"Create thunks that make the ops and vars on demand.\n\n    This function returns 4 lists of thunks: cov_variable_thunks,\n    cov_update_thunks, inv_variable_thunks, and inv_update_thunks.\n\n    The length of each list is the number of factors and the i-th element of\n    each list corresponds to the i-th factor (given by the \"factors\" property).\n\n    Note that the execution of these thunks must happen in a certain\n    partial order.  The i-th element of cov_variable_thunks must execute\n    before the i-th element of cov_update_thunks (and also the i-th element\n    of inv_update_thunks).  
Similarly, the i-th element of inv_variable_thunks\n    must execute before the i-th element of inv_update_thunks.\n\n    TL;DR (oversimplified): Execute the thunks according to the order that\n    they are returned.\n\n    Returns:\n      cov_variable_thunks: A list of thunks that make the cov variables.\n      cov_update_thunks: A list of thunks that make the cov update ops.\n      inv_variable_thunks: A list of thunks that make the inv variables.\n      inv_update_thunks: A list of thunks that make the inv update ops.\n    \"\"\"\n    scope = self.get_name() + \"/\" + self._fisher_est.name\n    return self._fisher_est.create_ops_and_vars_thunks(scope=scope)\n\n  def check_var_list(self, var_list):\n    if set(var_list) != set(self.registered_variables):\n      raise ValueError(\"var_list doesn't match with set of Fisher-estimating \"\n                       \"variables (i.e. those that were registered).\")\n\n  @staticmethod\n  def _scale_loss(loss_value):\n    # tf.compat.v1.train.Optimizer uses this method to account for the Estimator\n    # + Distribution Strategy (DS) case. DS wants a scaled loss and to aggregate\n    # gradients via a sum. Estimator provides an unscaled loss by default. So,\n    # this method would divide the loss by num_replicas. 
For our optimizer, we\n    # require users to pass in an unscaled loss, so we do not want this method\n    # to alter Estimator's input when it's used with DS.\n    return loss_value\n\n  def minimize(self,\n               loss,\n               global_step=None,\n               var_list=None,\n               gate_gradients=tf.train.Optimizer.GATE_OP,\n               aggregation_method=None,\n               colocate_gradients_with_ops=True,\n               name=None,\n               grad_loss=None,\n               **kwargs):\n    # This method has the same general arguments as the minimize methods in\n    # standard optimizers do.\n    # With most optimizers used with Distribution Strategy (DS), the user is\n    # expected to scale their loss by 1.0 / global_batch_size, then DS\n    # aggregates the gradients via a sum. We expect users to pass in a loss that\n    # is normalized by the per-replica batch size only. This is so we can\n    # handle the Estimator and DS cases in a consistent way. 
As a side effect,\n    # this means each replica must have the same per-replica batch size.\n\n    if var_list is None:\n      var_list = self.registered_variables\n    else:\n      self.check_var_list(var_list)\n\n    return super(KfacOptimizer, self).minimize(\n        loss,\n        global_step=global_step,\n        var_list=var_list,\n        gate_gradients=gate_gradients,\n        aggregation_method=aggregation_method,\n        colocate_gradients_with_ops=colocate_gradients_with_ops,\n        name=name,\n        grad_loss=grad_loss,\n        **kwargs)\n\n  def compute_gradients(self,\n                        loss,\n                        var_list=None,\n                        gate_gradients=tf.train.Optimizer.GATE_OP,\n                        aggregation_method=None,\n                        colocate_gradients_with_ops=True,\n                        grad_loss=None,\n                        **kwargs):\n    # This method has the same general arguments as the minimize methods in\n    # standard optimizers do. Unlike the compute_gradient method for typical\n    # optimizer implementations, this one performs cross-replica syncronization\n    # automatically when under one the supported replicated contexts, and so\n    # use of things like CrossShardOptimizer is unessesary (and wasteful).\n\n    if var_list is not None:\n      self.check_var_list(var_list)\n\n    grads_and_vars = super(KfacOptimizer, self).compute_gradients(\n        loss=loss,\n        var_list=var_list,\n        gate_gradients=gate_gradients,\n        aggregation_method=aggregation_method,\n        colocate_gradients_with_ops=colocate_gradients_with_ops,\n        grad_loss=grad_loss,\n        **kwargs)\n\n    # When using the TF Keras fused BatchNormalization implementation, in some\n    # cases the gradient shape is ?. KFAC needs the gradient shape in at least\n    # two cases: when registering a layer as generic, and when computing the\n    # qmodel. 
The gradient should have the same shape as the variable, so when\n    # any dimension is None we set the shape ourselves.\n    for grad, var in grads_and_vars:\n      if len(grad.shape) and not all(grad.shape.as_list()):\n        grad.set_shape(var.shape)\n\n    grads, vars_ = list(zip(*grads_and_vars))\n    grads = utils.all_average(grads)\n\n    return tuple(zip(grads, vars_))\n\n  def _is_damping_adaptation_time(self):\n    # Note that we update damping at the step right before the end of the\n    # interval, instead of at the beginning of the next interval. This is\n    # so it properly lines up with the periodic inverse updates (i.e. happens\n    # immediately before them.)\n    return tf.equal(tf.mod(self.counter + 1,\n                           self._damping_adaptation_interval),\n                    0)\n\n  def _is_just_after_damping_adaptation_time(self):\n    is_just_after = tf.equal(\n        tf.mod(self.counter, self._damping_adaptation_interval), 0)\n\n    return tf.logical_and(is_just_after, tf.not_equal(self.counter, 0))\n\n  def _maybe_update_prev_loss(self):\n    if self._adapt_damping:\n      should_update_prev_loss = self._is_damping_adaptation_time()\n\n      def update_prev_loss():\n        loss = self._loss_tensor if self._use_passed_loss else self._loss_fn(\n            self._train_batch)\n        loss = utils.all_average(loss)\n        return tf.group(utils.smart_assign(self._prev_loss, loss,\n                                           force_cast=True))\n\n      maybe_update_prev_loss_op = tf.cond(\n          should_update_prev_loss,\n          update_prev_loss,\n          tf.no_op)\n\n      return maybe_update_prev_loss_op\n    else:\n      return tf.no_op()\n\n  def maybe_pre_update_adapt_damping(self):\n    \"\"\"Maybe adapt the damping according to the built-in scheme.\n\n    Unless using a convenience class like PeriodicInvCovUpdateKfacOpt the op\n    returned by this function should be run every sess.run call, preferably\n    before 
the inv ops (using a control dependency).\n\n    Returns:\n      An op that applies the specified gradients, and also updates the counter\n      variable.\n    \"\"\"\n    if (not self._adapt_damping or not self._is_chief or\n        self._update_damping_immediately):\n      return tf.no_op()\n\n    # We update the damping on the iteration that is technically after\n    # where we compute qmodel_change.  However, it should happen before\n    # anything else does, so it's as if we computed it on the previous\n    # iteration.  The only reason we do it this way and not on the\n    # actual iteration is due to weirdness related to parameter servers\n    # or possibly just non-resource variables. Essentially, the model\n    # variables won't be updated and so we can't properly compute\n    # prev_batch_loss until the next sess.run() call.\n    should_update_damping = self._is_just_after_damping_adaptation_time()\n\n    maybe_update_damping = tf.cond(\n        should_update_damping,\n        self._update_damping,\n        tf.no_op)\n    return maybe_update_damping\n\n  def _maybe_post_update_adapt_damping(self):\n    if not self._update_damping_immediately or not self._adapt_damping:\n      return tf.no_op()\n\n    should_update_damping = self._is_damping_adaptation_time()\n\n    maybe_update_damping = tf.cond(\n        should_update_damping,\n        self._update_damping,\n        tf.no_op)\n    return maybe_update_damping\n\n  def apply_gradients(self, grads_and_vars, *args, **kwargs):\n    \"\"\"Apply updates to variables.\n\n    Args:\n      grads_and_vars: List of (gradient, variable) pairs.\n      *args: Additional arguments for super.apply_gradients.\n      **kwargs: Additional keyword arguments for super.apply_gradients.\n\n    Returns:\n      An op that applies the specified gradients, and also updates the counter\n      variable.\n    \"\"\"\n    maybe_update_prev_loss = self._maybe_update_prev_loss()\n\n    with 
tf.control_dependencies([maybe_update_prev_loss]):\n      # In Python 3, grads_and_vars can be a zip() object which can only be\n      # iterated over once. By converting it to a list, we ensure that it can be\n      # iterated over more than once.\n      grads_and_vars = list(grads_and_vars)\n\n      with tf.variable_scope(self.get_name()):\n        # Compute raw update step (self._learning_rate not yet applied).\n        # Note that this function also updates the velocity vectors.\n        raw_updates_and_vars = self._compute_raw_update_steps(grads_and_vars)\n\n      if self._use_weight_decay:\n        raw_updates_and_vars = self._add_weight_decay(raw_updates_and_vars)\n\n      if tf.distribute.has_strategy():\n        # Distribution Strategy (DS) expects users to pass in loss /\n        # global_batch_size to minimize. We require users not to do this, so our\n        # code can consistently deal with input in the single device, Estimator,\n        # and DS cases. However, the _distributed_apply call in\n        # super(...).apply_gradients(...) will perform a sum over replicas to\n        # aggregate the gradients. 
Therefore, we divide by the number of\n        # replicas so that the scaling of the update applied to the variables\n        # is correct.\n        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync\n        raw_updates_and_vars = [(update/num_replicas, var)\n                                for update, var in raw_updates_and_vars]\n\n      # Update trainable variables with this step, applying self._learning_rate.\n      apply_op = super(KfacOptimizer, self).apply_gradients(\n          raw_updates_and_vars, *args, **kwargs)\n\n      with tf.control_dependencies([apply_op]):\n        maybe_post_update_damping_op = self._maybe_post_update_adapt_damping()\n\n        with tf.control_dependencies([maybe_post_update_damping_op]):\n          maybe_print_logging_info = self._maybe_print_logging_info()\n          with tf.control_dependencies([maybe_print_logging_info]):\n            # Update the main counter\n            return tf.group(\n                utils.smart_assign(self._counter, 1, assign_fn=tf.assign_add))\n\n  def _add_weight_decay(self, grads_and_vars):\n    \"\"\"Applies weight decay.\n\n    Args:\n      grads_and_vars: List of (gradient, variable) pairs.\n\n    Returns:\n      List of (gradient, variable) pairs.\n    \"\"\"\n    return [(grad + self._weight_decay_coeff * tf.stop_gradient(var), var)\n            for grad, var in grads_and_vars]\n\n  def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars):\n    \"\"\"Computes the squared (approximate) Fisher norm of the updates.\n\n    This is defined as v^T F v, where F is the approximate Fisher matrix\n    as computed by the estimator, and v = F^{-1} g, where g is the gradient.\n    This is computed efficiently as v^T g.\n\n    Args:\n      grads_and_vars: List of (gradient, variable) pairs.\n      precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.\n        Must be the result of calling `self._multiply_preconditioner`\n        on `grads_and_vars`.\n\n    
Returns:\n      Scalar representing the squared norm.\n\n    Raises:\n      ValueError: if the two list arguments do not contain the same variables,\n        in the same order.\n    \"\"\"\n    return ip_p(grads_and_vars, precon_grads_and_vars)\n\n  def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars):\n    \"\"\"Computes the scale factor for the update to satisfy the norm constraint.\n\n    Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint,\n    F is the approximate Fisher matrix, and r is the update vector, i.e.\n    -alpha * v, where alpha is the learning rate, and v is the preconditioned\n    gradient.\n\n    This is based on Section 5 of Ba et al., Distributed Second-Order\n    Optimization using Kronecker-Factored Approximations. Note that they\n    absorb the learning rate alpha (which they denote eta_max) into the formula\n    for the coefficient, while in our implementation, the rescaling is done\n    before multiplying by alpha. Hence, our formula differs from theirs by a\n    factor of alpha.\n\n    Args:\n      grads_and_vars: List of (gradient, variable) pairs.\n      precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.\n        Must be the result of calling `self._multiply_preconditioner`\n        on `grads_and_vars`.\n\n    Returns:\n      Scalar representing the coefficient which should be applied to the\n      preconditioned gradients to satisfy the norm constraint.\n    \"\"\"\n    sq_norm_grad = self._squared_fisher_norm(grads_and_vars,\n                                             precon_grads_and_vars)\n    sq_norm_up = sq_norm_grad * self._learning_rate**2\n    return tf.minimum(\n        tf.ones(shape=(), dtype=sq_norm_up.dtype),\n        tf.sqrt(self._norm_constraint / sq_norm_up))\n\n  def _clip_updates(self, grads_and_vars, precon_grads_and_vars):\n    \"\"\"Rescales the preconditioned gradients to satisfy the norm constraint.\n\n    Rescales the preconditioned gradients such that the 
resulting update r\n    (after multiplying by the learning rate) will satisfy the norm constraint.\n    This constraint is that r^T F r <= C, where F is the approximate Fisher\n    matrix, and C is the norm_constraint attribute. See Section 5 of\n    Ba et al., Distributed Second-Order Optimization using Kronecker-Factored\n    Approximations.\n\n    Args:\n      grads_and_vars: List of (gradient, variable) pairs.\n      precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.\n        Must be the result of calling `self._multiply_preconditioner`\n        on `grads_and_vars`.\n\n    Returns:\n      List of (rescaled preconditioned gradient, variable) pairs.\n    \"\"\"\n    coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars)\n    return sprod_p(coeff, precon_grads_and_vars)\n\n  def _compute_prev_updates(self, variables):\n    \"\"\"Returns the previous update vector computed using the quadratic model.\n\n    Note that this vector does not include any additional scaling that may have\n    been applied after the quadratic model optimization (i.e. the quantity\n    returned by self.learning_rate).\n\n    Note that this may not actually be the previous update if\n    momentum_type=\"adam\".\n\n    Args:\n      variables: List of variables for which to compute the previous update.\n\n    Returns:\n      List of (previous_update, variable) pairs in the same order as\n      `variables`.\n    \"\"\"\n    # What guarantee do we have that this is the old value and not the\n    # new value?  
Remember that control flow doesn't work in TF whenever\n    # non-resource variables are involved.\n    # TODO(b/121245468): Figure out if this is a problem and if not explain why\n    # Or fix it by somehow forcing the slots to use resource variables instead.\n\n    prev_updates = sprod(\n        -1., tuple(self._zeros_slot(var, \"velocity\", self.get_name())\n                   for var in variables))\n    return tuple(zip(prev_updates, variables))\n\n  def _compute_qmodel(self,\n                      raw_updates_and_vars,\n                      prev_updates_and_vars,\n                      grads_and_vars,\n                      should_average_over_replicas=True):\n    \"\"\"Computes the 2 dimensional version of the (exact) quadratic model.\n\n       The two dimesions are the update and the previous update vectors.\n\n       The arguments are all lists of (Tensor, Variable) pairs where the\n       variables are the same and in the same order.\n\n    Args:\n      raw_updates_and_vars: a list of (precond grad, variable) pairs. Raw update\n        proposal to apply to the variables (before scaling by learning rate and\n        addition of velocity/momentum).\n      prev_updates_and_vars: a list of (previous update, variable) pairs.\n        Previous update applied to the variables (includes learning rate and\n        velocity/momentum).\n      grads_and_vars: a list of (gradient, variable) pairs. Gradients for the\n        parameters and the variables that the updates are being applied to. The\n        order of this list must correspond to the order of the other arguments.\n        (Note that this function doesn't actually apply the update.)\n      should_average_over_replicas: a bool. If true, results will be averged\n        over replicas (using utils.all_average). (Default: True)\n\n    Returns:\n      m, c, and b. 
m is the 2 by 2 matrix representing the quadratic term,\n      c is a 2 by 1 vector representing the linear term, and b is the 2 by 2\n      matrix representing only the contribution of the damping to the quadratic\n      term. These are all multi-dimensional lists (lists of lists) of Tensors.\n    \"\"\"\n\n    # Raw update proposal to apply to the variables (before scaling by learning\n    # rate and addition of velocity/momentum).\n    raw_updates, _ = zip(*raw_updates_and_vars)\n    prev_updates, _ = zip(*prev_updates_and_vars)\n    grads, variables = zip(*grads_and_vars)\n\n    utils.assert_variables_match_pairs_list(\n        raw_updates_and_vars, prev_updates_and_vars,\n        error_message=\"_compute_qmodel raw_updates_and_vars and \"\n        \"prev_updates_and_vars differ.\")\n    utils.assert_variables_match_pairs_list(\n        prev_updates_and_vars, grads_and_vars,\n        error_message=\"_compute_qmodel prev_updates_and_vars and \"\n        \"grads_and_vars differ.\")\n\n    cmvpc = cmvp.CurvatureMatrixVectorProductComputer(\n        self.layers,\n        variables,\n        colocate_gradients_with_ops=self._colocate_gradients_with_ops)\n\n    # Compute the matrix-vector products with the transposed Fisher factor\n    # (or GGN factor)\n    if self.mat_type == \"Fisher\":\n      mft_updates = cmvpc.multiply_fisher_factor_transpose(raw_updates)\n      mft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates)\n    elif self.mat_type == \"GGN\" or self.mat_type == \"Empirical_Fisher\":\n      mft_updates = cmvpc.multiply_ggn_factor_transpose(raw_updates)\n      mft_prev_updates = cmvpc.multiply_ggn_factor_transpose(prev_updates)\n\n    batch_size = tf.cast(self._batch_size, dtype=mft_updates[0].dtype)\n\n    damping = tf.cast(self.damping, dtype=raw_updates[0].dtype)\n    b_11 = damping * ip(raw_updates, raw_updates)\n    b_21 = damping * ip(prev_updates, raw_updates)\n    b_22 = damping * ip(prev_updates, prev_updates)\n    b = [[b_11, 
b_21], [b_21, b_22]]\n\n    # Compute the entries of the 2x2 matrix\n    m_11 = ip(mft_updates, mft_updates) / batch_size\n    m_21 = ip(mft_prev_updates, mft_updates) / batch_size\n    m_22 = (ip(mft_prev_updates, mft_prev_updates)\n            / batch_size)\n    m = [[m_11 + b_11, m_21 + b_21],\n         [m_21 + b_21, m_22 + b_22]]\n\n    if should_average_over_replicas:\n      m = utils.all_average(m)\n\n    c_1 = ip(grads, raw_updates)\n    c_2 = ip(grads, prev_updates)\n\n    c = [[c_1], [c_2]]\n\n    return m, c, b\n\n  @property\n  def _sub_damping_out_qmodel_change_coeff(self):\n    return 1.0 - self._l2_reg / self.damping\n\n  def _compute_qmodel_hyperparams(self, m, c, b, fixed_mu=None):\n    \"\"\"Compute optimal update hyperparameters from the quadratic model.\n\n    More specifically, if L is the loss we minimize a quadratic approximation\n    of L(theta + d) which we denote by qmodel(d) with\n    d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where\n\n      qmodel(d) = (1/2) * d^T * C * d + grad^T*d + L(theta) .\n\n    Unlike in the KL clipping approach we use the non-approximated quadratic\n    model where the curvature matrix C is the true Fisher (or GGN) on the\n    current mini-batch (computed without any approximations beyond mini-batch\n    sampling), with the usual Tikhonov damping/regularization applied,\n\n      C = F + damping * I\n\n    See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of\n    the formula.  
See Appendix C for a discussion of the trick of using\n    a factorized Fisher matrix to more efficiently compute the required\n    vector-matrix-vector products.\n\n    Args:\n      m: 2 by 2 matrix representing the quadratic term (a list of list of\n        0D Tensors)\n      c: a 2 by 1 vector representing the linear term (a list of 0D Tensors)\n      b: 2 by 2 matrix representing only the contribution of the damping to the\n        quadratic term\n      fixed_mu: A fixed value of mu to use instead of the optimal one.\n        (Default: None)\n    Returns:\n      (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the\n      quadratic model, and\n      qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0)\n                    = qmodel(alpha*precon_grad + mu*prev_update) - L(theta).\n    \"\"\"\n\n    def non_zero_prevupd_case():\n      r\"\"\"Computes optimal (alpha, mu) given non-zero previous update.\n\n      We solve the full 2x2 linear system. See Martens & Grosse (2015),\n      Section 7, definition of $\\alpha^*$ and $\\mu^*$.\n\n      Returns:\n        (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize\n        the quadratic model, and\n        qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0).\n      \"\"\"\n      if fixed_mu is None:\n        sol = -1. * _two_by_two_solve(m, c)\n        alpha = sol[0, 0]\n        mu = sol[1, 0]\n\n        if self._qmodel_update_rescale is None:\n          # This is a special formula that takes advantage of the particular\n          # relationship of sol to m and c. 
It should be equivalent to\n          # _eval_quadratic(m, c, sol) if everything is working properly.\n          qmodel_change = 0.5 * tf.reduce_sum(sol * c)\n        else:\n          sol = self._qmodel_update_rescale * sol\n          qmodel_change = _eval_quadratic(m, c, sol)\n\n        # Subtract out the damping-related penalty\n        if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE:\n          qmodel_change -= (self._sub_damping_out_qmodel_change_coeff\n                            * _eval_quadratic_no_c(b, sol))\n\n      else:\n        alpha = -1. * (fixed_mu * m[0][1] + c[0][0]) / (m[0][0])\n        mu = fixed_mu\n\n        sol = [[alpha], [mu]]\n\n        if self._qmodel_update_rescale is not None:\n          sol = self._qmodel_update_rescale * tf.convert_to_tensor(sol)\n\n        qmodel_change = _eval_quadratic(m, c, sol)\n\n        # Subtract out the damping-related penalty\n        if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE:\n          qmodel_change -= (self._sub_damping_out_qmodel_change_coeff\n                            * _eval_quadratic_no_c(b, sol))\n\n      return tf.squeeze(alpha), tf.squeeze(mu), tf.squeeze(qmodel_change)\n\n    def zero_prevupd_case():\n      r\"\"\"Computes optimal (alpha, mu) given all-zero previous update.\n\n      The linear system reduces to 1x1. 
See Martens & Grosse (2015),\n      Section 6.4, definition of $\\alpha^*$.\n\n      Returns:\n        (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the\n        quadratic model, and\n        qmodel_change = qmodel(alpha*precon_grad) - qmodel(0)\n      \"\"\"\n      alpha = -c[0][0] / m[0][0]\n      if fixed_mu is None:\n        mu = 0.0\n      else:\n        mu = fixed_mu\n\n      mu = tf.cast(mu, dtype=alpha.dtype)\n\n      if self._qmodel_update_rescale is None:\n        # This is a special formula that takes advantage of the particular\n        # relationship of sol to m and c.\n        qmodel_change = 0.5 * alpha * c[0][0]\n\n        # Subtract out the damping-related penalty\n        if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE:\n          qmodel_change -= (self._sub_damping_out_qmodel_change_coeff\n                            * 0.5 * tf.square(alpha) * b[0][0])\n      else:\n        sol = self._qmodel_update_rescale * alpha\n        qmodel_change = 0.5 * m[0][0] * tf.square(sol) + c[0][0] * sol\n        # Subtract out the damping-related penalty\n        if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE:\n          qmodel_change -= (self._sub_damping_out_qmodel_change_coeff\n                            * 0.5 * tf.square(sol) * b[0][0])\n\n      return alpha, mu, qmodel_change\n\n    return tf.cond(\n        tf.equal(c[1][0], 0.0),\n        zero_prevupd_case,\n        non_zero_prevupd_case)\n\n  def _compute_approx_qmodel_change(self, updates_and_vars, grads_and_vars):\n    \"\"\"Computes the change in the approximate quadratic model.\n\n    'Approximate' means the quadratic model which uses the approximate\n    Fisher/GGN as the curvature matrix, instead of the exact Fisher/GGN which\n    is used by _compute_qmodel and its dependent methods.\n\n    Args:\n      updates_and_vars: List of (update, variable) pairs.\n      grads_and_vars: List of (gradient, variable) pairs.\n\n    Returns:\n      A 0D Tensor which is the change in the approximate quadratic 
model.\n    \"\"\"\n\n    quad_term = 0.5*ip_p(updates_and_vars,\n                         self._fisher_est.multiply(updates_and_vars))\n\n    if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE:\n      # This isn't quite right, but doing it properly is too awkward.\n      quad_term -= (self._sub_damping_out_qmodel_change_coeff *\n                    0.5 * self.damping*ip_p(updates_and_vars, updates_and_vars))\n    linear_term = ip_p(updates_and_vars, grads_and_vars)\n\n    return quad_term + linear_term\n\n  def _maybe_update_qmodel_change(self, qmodel_change_thunk):\n    \"\"\"Returns an op which updates the qmodel_change variable if it is time to.\n\n    Args:\n      qmodel_change_thunk: A callable which when evaluated returns the qmodel\n        change.\n\n    Returns:\n      An op.\n    \"\"\"\n    def update_qmodel_change():\n      # The tf.group is needed to strip away the value so it can be used\n      # in the cond later.\n      return tf.group(utils.smart_assign(self._qmodel_change,\n                                         tf.squeeze(qmodel_change_thunk()),\n                                         force_cast=True))\n\n    # Note that we compute the qmodel change and store it in a variable so\n    # it can be used at the next sess.run call (where rho will actually be\n    # computed).\n    return tf.cond(self._is_damping_adaptation_time(),\n                   update_qmodel_change, tf.no_op)\n\n  def _multiply_preconditioner(self, vecs_and_vars):\n    return self._fisher_est.multiply_inverse(vecs_and_vars)\n\n  def _get_qmodel_quantities(self, grads_and_vars):\n\n    # Compute \"preconditioned gradient\".\n    precon_grads_and_vars = self._multiply_preconditioner(grads_and_vars)\n\n    var_list = tuple(var for (_, var) in grads_and_vars)\n    prev_updates_and_vars = self._compute_prev_updates(var_list)\n\n    # While it might seem like this call performs needless computations\n    # involving prev_updates_and_vars in the case where it is zero, because\n    # we 
extract out only the part of the solution that is not zero the rest\n    # of it will not actually be computed by TensorFlow (I think).\n    m, c, b = self._compute_qmodel(\n        precon_grads_and_vars, prev_updates_and_vars, grads_and_vars)\n\n    return precon_grads_and_vars, m, c, b\n\n  def _compute_raw_update_steps(self, grads_and_vars):\n    \"\"\"Computes the raw update steps for the variables given the gradients.\n\n    Note that these \"raw updates\" are further multiplied by\n    -1*self._learning_rate when the update is eventually applied in the\n    superclass (which is GradientDescentOptimizer).\n\n    Args:\n      grads_and_vars: List of (gradient, variable) pairs.\n\n    Returns:\n      A list of tuples (raw_update, var) where raw_update is the update to\n      the parameter. These updates must be actually used since they carry\n      with them certain control dependencies that need to happen.\n    \"\"\"\n\n    if self._momentum_type == \"regular\":\n      # Compute \"preconditioned\" gradient.\n      precon_grads_and_vars = self._multiply_preconditioner(grads_and_vars)\n\n      # Apply \"KL clipping\" if asked for.\n      if self._norm_constraint is not None:\n        precon_grads_and_vars = self._clip_updates(grads_and_vars,\n                                                   precon_grads_and_vars)\n\n      # Update the velocities and get their values as the \"raw\" updates\n      raw_updates_and_vars = self._update_velocities(precon_grads_and_vars,\n                                                     self._momentum)\n\n      if self._adapt_damping and self._is_chief:\n\n        def compute_qmodel_change():\n          updates_and_vars = sprod_p(-1. 
* self._learning_rate,\n                                     raw_updates_and_vars)\n          return self._compute_approx_qmodel_change(updates_and_vars,\n                                                    grads_and_vars)\n\n        maybe_update_qmodel_change = self._maybe_update_qmodel_change(\n            compute_qmodel_change)\n\n        with tf.control_dependencies([maybe_update_qmodel_change]):\n          # Making this a tuple is important so that it actually gets evaluated\n          # in the context.\n          return tuple((tf.identity(vec), var)\n                       for (vec, var) in raw_updates_and_vars)\n      else:\n        return raw_updates_and_vars\n\n    elif self._momentum_type == \"adam\":\n      velocities_and_vars = self._update_velocities(grads_and_vars,\n                                                    self._momentum)\n      # The \"preconditioned\" velocity vector is the raw update step.\n      raw_updates_and_vars = self._multiply_preconditioner(velocities_and_vars)\n\n      # Apply \"KL clipping\" if asked for. Note that we are applying this to\n      # the combined preconditioned gradient + velocity, unlike for the\n      # momentum_type = 'regular' case.\n      if self._norm_constraint is not None:\n        raw_updates_and_vars = self._clip_updates(velocities_and_vars,\n                                                  raw_updates_and_vars)\n\n      if self._adapt_damping and self._is_chief:\n        def compute_qmodel_change():\n          # This is a special formula that exploits the structure of the\n          # particular update we are using.  
Note that this is using the approx\n          # Fisher as defined by the inverses, which might be stale (perhaps so\n          # stale that they are using an old damping value, which may mess up\n          # the damping adaptation method).\n          return (0.5 * (self._learning_rate**2) *\n                  ip_p(raw_updates_and_vars, velocities_and_vars) -\n                  self._learning_rate * ip_p(raw_updates_and_vars,\n                                             grads_and_vars))\n\n        maybe_update_qmodel_change = self._maybe_update_qmodel_change(\n            compute_qmodel_change)\n\n        with tf.control_dependencies([maybe_update_qmodel_change]):\n          # Making this a tuple is important so that it actually gets evaluated\n          # in the context.\n          return tuple((tf.identity(vec), var)\n                       for (vec, var) in raw_updates_and_vars)\n      else:\n        return raw_updates_and_vars\n\n    elif (self._momentum_type == \"qmodel\"\n          or self._momentum_type == \"qmodel_fixedmu\"):\n\n      precon_grads_and_vars, m, c, b = self._get_qmodel_quantities(\n          grads_and_vars)\n\n      if self._momentum_type == \"qmodel_fixedmu\":\n        fixed_mu = self._momentum\n      else:\n        fixed_mu = None\n\n      # Compute optimal velocity update parameters according to quadratic\n      # model\n      alpha, mu, qmodel_change = self._compute_qmodel_hyperparams(\n          m, c, b, fixed_mu=fixed_mu)\n\n      qmodel_assign_op = tf.group(\n          utils.smart_assign(self._qmodel_change, qmodel_change,\n                             force_cast=True),\n          utils.smart_assign(self._qmodel_learning_rate, -alpha,\n                             force_cast=True),\n          utils.smart_assign(self._qmodel_momentum, mu,\n                             force_cast=True))\n\n      with tf.control_dependencies([qmodel_assign_op]):\n        return self._update_velocities(\n            precon_grads_and_vars, mu, 
vec_coeff=-alpha)\n\n  # NOTE: the very particular way this function is written is probably important\n  # for it to work correctly with non-resource variables, which are very\n  # unpredictable with regards to control flow.\n  def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0):\n    \"\"\"Updates the velocities of the variables with the given vectors.\n\n    Args:\n      vecs_and_vars: List of (vector, variable) pairs.\n      decay: How much to decay the old velocity by.  This is often referred to\n        as the 'momentum constant'.\n      vec_coeff: Coefficient to apply to the vectors before adding them to the\n        velocity.\n\n    Returns:\n      A list of (velocity, var) indicating the new velocity for each var.\n    \"\"\"\n    def _update_velocity(vec, var):\n      velocity = self._zeros_slot(var, \"velocity\", self.get_name())\n      with tf.colocate_with(velocity):\n        # NOTE(mattjj): read/modify/write race condition not suitable for async.\n\n        # Compute the new velocity for this variable.\n        new_velocity = decay * velocity + vec_coeff * vec\n\n        # Save the updated velocity.\n        return (tf.identity(utils.smart_assign(velocity, new_velocity)), var)\n\n    # Go through variable and update its associated part of the velocity vector.\n    return [_update_velocity(vec, var) for vec, var in vecs_and_vars]\n\n  def _get_current_loss(self):\n    if self._update_damping_immediately:\n      return utils.all_average(self._loss_fn(self._train_batch))\n\n    return utils.all_average(self._loss_fn(self._prev_train_batch))\n\n  def _get_prev_loss(self):\n    return tf.identity(self._prev_loss)\n\n  def _update_damping(self):\n    \"\"\"Adapts damping parameter. 
Check KFAC paper (Section 6.5) for the details.\n\n    The damping parameter is updated according to the Levenberg-Marquardt rule\n    every `self._damping_adaptation_interval` iterations.\n\n    Essentially, the rule computes the reduction ratio \"rho\" and depending on\n    the value either increases lambda, decreases it, or leaves it as is.\n\n    The reduction ratio captures how closely the quadratic approximation to the\n    loss function approximates the actual loss within a trust region. The\n    damping update tries to make the damping as small as possible while\n    maintaining the property that the quadratic model remains a good local\n    approximation to the loss function.\n\n    Returns:\n      An Op to assign newly computed damping value to `self._damping`, and also\n      updates the _rho member.\n    \"\"\"\n    prev_loss = self._get_prev_loss()\n    current_loss = tf.cast(self._get_current_loss(), dtype=prev_loss.dtype)\n\n    loss_change = current_loss - prev_loss\n    rho = loss_change / self._qmodel_change\n\n    should_decrease = tf.math.logical_or(\n        tf.math.logical_and(loss_change < 0, self._qmodel_change > 0),\n        rho > self._damping_decrease_rho_threshold)\n    should_increase = rho < self._damping_increase_rho_threshold\n\n    new_damping = tf.case(\n        [(should_decrease, lambda: self.damping * self._omega),\n         (should_increase, lambda: self.damping / self._omega)],\n        default=lambda: self.damping)\n\n    new_damping = tf.maximum(new_damping, self._min_damping + self._l2_reg)\n\n    return tf.group(utils.smart_assign(self._damping, new_damping),\n                    utils.smart_assign(self._rho, rho, force_cast=True))\n\n\ndef _two_by_two_solve(m, vec):\n  \"\"\"Solve a 2x2 system by direct inversion.\n\n  Args:\n    m: A length 2 list of length 2 lists, is a 2x2 matrix of [[a, b], [c, d]].\n    vec: The length 2 list of length 1 lists, a vector of [e, f].\n\n  Returns:\n    matmul(m^{-1}, vec).\n  \"\"\"\n  a 
= m[0][0]\n  b = m[0][1]\n  c = m[1][0]\n  d = m[1][1]\n  inv_m_det = 1.0 / (a * d - b * c)\n  m_inverse = [\n      [d * inv_m_det, -b * inv_m_det],\n      [-c * inv_m_det, a * inv_m_det]\n  ]\n  return tf.matmul(m_inverse, vec)\n\n\ndef _eval_quadratic_no_c(m, vec):\n  return 0.5*tf.matmul(tf.matmul(vec, m, transpose_a=True), vec)\n\n\ndef _eval_quadratic(m, c, vec):\n  return _eval_quadratic_no_c(m, vec) + tf.matmul(c, vec, transpose_a=True)\n"
  },
  {
    "path": "kfac/python/ops/placement.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Implements placement strategies for various ops and variables.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport itertools\n\n# Dependency imports\nimport tensorflow.compat.v1 as tf\n\nfrom tensorflow.python.util import nest\nfrom kfac.python.ops import utils as utils\n\n\ndef _make_thunk_on_device(func, device):\n  def thunk(*args, **kwargs):\n    with tf.device(device):\n      return func(*args, **kwargs)\n  return thunk\n\n\nclass RoundRobinPlacementMixin(object):\n  \"\"\"Implements round robin placement strategy for ops and variables.\"\"\"\n\n  def __init__(self, cov_devices=None, inv_devices=None, trans_devices=None,\n               **kwargs):\n    \"\"\"Create a RoundRobinPlacementMixin object.\n\n    Args:\n      cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance\n        computations will be placed on these devices in a round-robin fashion.\n        Can be None or empty, which means that no devices are specified.\n      inv_devices: Iterable of device strings (e.g. '/gpu:0'). 
Inversion\n        computations will be placed on these devices in a round-robin fashion.\n        Can be None or empty, which means that no devices are specified.\n      trans_devices: Iterable of device strings (e.g. '/gpu:0'). Transformation\n        computations (e.g. multiplying different blocks by the inverse Fisher)\n        will be placed on these devices in a round-robin fashion. Can be None\n        or empty, which means that no devices are specified.\n      **kwargs: Pass through arguments.\n    \"\"\"\n    super(RoundRobinPlacementMixin, self).__init__(**kwargs)\n    self._cov_devices = cov_devices\n    self._inv_devices = inv_devices\n    self._trans_devices = trans_devices\n\n  def _place_and_compute_transformation_thunks(self, thunks, params_list):\n    \"\"\"Computes transformation thunks with round-robin device placement.\n\n    Device placement done in round-robin fashion according to the order of\n    the `blocks` property, using the list `trans_devices` passed in to the\n    constructor.\n\n    Args:\n      thunks: A list of thunks to run. Must be in one to one correspondence\n        with the `blocks` property.\n      params_list: A list of the corresponding parameters. 
Must be in one to one\n        correspondence with the `blocks` property.\n\n    Returns:\n      A list (in the same order) of the returned results of the thunks, with\n      round-robin device placement applied.\n    \"\"\"\n    del params_list\n\n    if self._trans_devices:\n      results = []\n      for thunk, device in zip(thunks, itertools.cycle(self._trans_devices)):\n        with tf.device(device):\n          results.append(thunk())\n      return results\n    else:\n      return tuple(thunk() for thunk in thunks)\n\n  def create_ops_and_vars_thunks(self, scope=None):\n    \"\"\"Create thunks that make the ops and vars on demand with device placement.\n\n    For each factor, all of that factor's cov variables and their associated\n    update ops will be placed on a particular device.  A new device is chosen\n    for each factor by cycling through list of devices in the\n    `self._cov_devices` attribute. If `self._cov_devices` is `None` then no\n    explicit device placement occurs.\n\n    An analogous strategy is followed for inverse update ops, with the list of\n    devices being given by the `self._inv_devices` attribute.\n\n    Inverse variables on the other hand are not placed on any specific device\n    (they will just use the current the device placement context, whatever\n    that happens to be).  The idea is that the inverse variable belong where\n    they will be accessed most often, which is the device that actually applies\n    the preconditioner to the gradient. The user will be responsible for setting\n    the device context for this.\n\n    This function returns 4 lists of thunks: cov_variable_thunks,\n    cov_update_thunks, inv_variable_thunks, and inv_update_thunks.\n\n    The length of each list is the number of factors and the i-th element of\n    each list corresponds to the i-th factor (given by the \"factors\" property).\n\n    Note that the execution of these thunks must happen in a certain\n    partial order.  
The i-th element of cov_variable_thunks must execute\n    before the i-th element of cov_update_thunks (and also the i-th element\n    of inv_update_thunks).  Similarly, the i-th element of inv_variable_thunks\n    must execute before the i-th element of inv_update_thunks.\n\n    TL;DR (oversimplified): Execute the thunks according to the order that\n    they are returned.\n\n    Args:\n      scope: A string or None.  If None it will be set to the name of this\n        estimator (given by the name property). All variables will be created,\n        and all thunks will execute, inside of a variable scope of the given\n        name. (Default: None)\n\n    Returns:\n      cov_variable_thunks: A list of thunks that make the cov variables.\n      cov_update_thunks: A list of thunks that make the cov update ops.\n      inv_variable_thunks: A list of thunks that make the inv variables.\n      inv_update_thunks: A list of thunks that make the inv update ops.\n    \"\"\"\n    (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,\n     inv_update_thunks_raw) = self._create_ops_and_vars_thunks(scope=scope)\n\n    if self._cov_devices:\n      cov_variable_thunks = []\n      cov_update_thunks = []\n      for cov_variable_thunk, cov_update_thunk, device in zip(\n          cov_variable_thunks_raw, cov_update_thunks_raw,\n          itertools.cycle(self._cov_devices)):\n\n        cov_variable_thunks.append(_make_thunk_on_device(cov_variable_thunk,\n                                                         device))\n        cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,\n                                                       device))\n    else:\n      cov_variable_thunks = cov_variable_thunks_raw\n      cov_update_thunks = cov_update_thunks_raw\n\n    inv_variable_thunks = inv_variable_thunks_raw\n\n    if self._inv_devices:\n      inv_update_thunks = []\n      for inv_update_thunk, device in zip(inv_update_thunks_raw,\n                            
              itertools.cycle(self._inv_devices)):\n        inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,\n                                                       device))\n    else:\n      inv_update_thunks = inv_update_thunks_raw\n\n    return (cov_variable_thunks, cov_update_thunks,\n            inv_variable_thunks, inv_update_thunks)\n\n\nclass ReplicaRoundRobinPlacementMixin(object):\n  \"\"\"Implements round robin placement strategy for certain ops on replicas.\n\n  This placement strategy can be used in certain TPU training systems, where\n  there are multiple \"replicas\" of the graph, such as in TPUEstimator or\n  TF-Replicator. The execution of inverse and transformation ops, which by\n  default occurs redundantly on all replicas, are instead distributed over\n  replicas in a round-robin fashion. This is achieved by using tf.cond\n  statements to check the replica id number.\n\n  This placement strategy doesn't need to be used with TPU training, and may\n  not work with all possible setups (such as TF Replicator). When it does work\n  however, it may provide a substantial improvement in wall-clock time.\n  \"\"\"\n\n  def __init__(self, distribute_transformations=True, **kwargs):\n    \"\"\"Create a ReplicaRoundRobinPlacementMixin object.\n\n    Args:\n      distribute_transformations: Bool. If True we distribute certain vector\n        transformations, such as multiplication by the preconditioner, across\n        different replicas. 
Because this is a cheaper operation it may not\n        always be worth the increase communication cost to do this.\n        (Default: True)\n      **kwargs: Pass through arguments.\n    \"\"\"\n\n    if not utils.is_replicated():\n      raise ValueError(\"This placement mode should only be used with certain \"\n                       \"kinds of 'replicated' setups, such as TPUEstimator \"\n                       \"or TF-Replicator.\")\n\n    self._distribute_transformations = distribute_transformations\n\n    super(ReplicaRoundRobinPlacementMixin, self).__init__(**kwargs)\n\n  def _place_and_compute_transformation_thunks(self, thunks, params_list):\n    \"\"\"Computes transformation thunks with round-robin replica placement.\n\n    Replica placement done in round-robin fashion according to the order of\n    the `blocks` property, cycling through the replicas in numerical order.\n\n    Args:\n      thunks: A list of thunks to run. Must be in one to one correspondence\n        with the `blocks` property.\n      params_list: A list of the corresponding parameters. 
Must be in one to one\n        correspondence with the `blocks` property.\n\n    Returns:\n      A list (in the same order) of the returned results of the thunks, with\n      round-robin replica placement applied.\n    \"\"\"\n    del params_list\n\n    return utils.map_gather(thunks)\n\n  def create_ops_and_vars_thunks(self, scope=None):\n    \"\"\"Create op/var-making thunks with replica placement for inverse ops.\n\n    For each factor in the list of factors, the associated inverse ops will\n    execute on a single replica which is chosen in round-robin fashion.\n\n    Cov ops are run on all replicas, with the appropriate averaging done by\n    using a few cross_replica_mean's that have been injected into the\n    FisherFactor classes (and execute regardless if this mixin is being used).\n\n    This function returns 4 lists of thunks: cov_variable_thunks,\n    cov_update_thunks, inv_variable_thunks, and inv_update_thunks.\n\n    The length of each list is the number of factors and the i-th element of\n    each list corresponds to the i-th factor (given by the \"factors\" property).\n    (Actually, for inv_update_thunks this class in particular returns only one\n    thunk inside inv_update_thunks that updates all the factors.)\n\n    Note that the execution of these thunks must happen in a certain\n    partial order.  The i-th element of cov_variable_thunks must execute\n    before the i-th element of cov_update_thunks (and also the i-th element\n    of inv_update_thunks).  Similarly, the i-th element of inv_variable_thunks\n    must execute before the i-th element of inv_update_thunks.\n\n    TL;DR (oversimplified): Execute the thunks according to the order that\n    they are returned.\n\n    Args:\n      scope: A string or None.  If None it will be set to the name of this\n        estimator (given by the name property). All variables will be created,\n        and all thunks will execute, inside of a variable scope of the given\n        name. 
(Default: None)\n\n    Returns:\n      cov_variable_thunks: A list of thunks that make the cov variables.\n      cov_update_thunks: A list of thunks that make the cov update ops.\n      inv_variable_thunks: A list of thunks that make the inv variables.\n      inv_update_thunks: A list of thunks that make the inv update ops.\n    \"\"\"\n\n    (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,\n     inv_update_thunks_raw) = self._create_ops_and_vars_thunks(scope=scope)\n\n    cov_variable_thunks = cov_variable_thunks_raw\n\n    # all_averages of cov values are performed internally in the FisherFactor\n    # classes, so we don't need to do anything for the cov updates here.\n    cov_update_thunks = cov_update_thunks_raw\n    inv_variable_thunks = inv_variable_thunks_raw\n\n    # The thunks made here execute the supplied inverse update thunk and then\n    # retrieve the values from the corresponding inverse variables.\n    def make_thunk(inv_update_thunk, inv_vars):\n      def thunk():\n        with tf.control_dependencies([inv_update_thunk()]):\n          return nest.map_structure(tf.identity, inv_vars)\n      return thunk\n\n    # This single thunk calls map_gather to distribute the work, and then\n    # saves the results back to the corresponding inverse variables.\n    def inv_update_thunk():\n\n      assert len(inv_update_thunks_raw) == len(self.factors)\n\n      # Create a list of factors and thunks that only include the factors\n      # that have inverse variables.  
Note that not executing the inverse ops of\n      # those that don't shouldn't matter.\n      factors_and_thunks = tuple(\n          (factor, thunk)\n          for factor, thunk in zip(self.factors, inv_update_thunks_raw)\n          if factor.get_inv_vars())\n\n      factors, _ = zip(*factors_and_thunks)\n\n      thunks = tuple(\n          make_thunk(inv_update_thunk, factor.get_inv_vars())\n          for factor, inv_update_thunk in factors_and_thunks)\n\n      results = utils.map_gather(thunks)\n\n      # These assigns save the values back to the variables.\n      ops = (utils.smart_assign(var, val)\n             for factor, result in zip(factors, results)  # pylint: disable=g-complex-comprehension\n             for val, var in zip(result, factor.get_inv_vars()))\n      return tf.group(*ops)\n\n    # Note that we have to return one big inv_update_thunk instead of one for\n    # each factor. This is because utils.map_gather doesn't support returning\n    # thunks (because TFReplicator's map_gather doesn't).\n    inv_update_thunks = [inv_update_thunk]\n\n    return (cov_variable_thunks, cov_update_thunks,\n            inv_variable_thunks, inv_update_thunks)\n\n"
  },
  {
    "path": "kfac/python/ops/tensormatch/__init__.py",
    "content": ""
  },
  {
    "path": "kfac/python/ops/tensormatch/graph_matcher.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Pattern matcher for TensorFlow graphs in the Python object model.\n\nWriting Python to crawl through TensorFlow graphs can be a pain, and the\nresulting code is often hard to adapt, extend, and reuse. Instead of\nhand-writing that code, we should automatically generate it from a simple\npattern-matching language. This package provides one such system.\n\nMore precisely, this package defines a pattern language for matching and\nextracting nodes from TensorFlow graphs as represented in the Python object\nmodel. Patterns can be defined in Python code with a simple syntax and are\ncompiled into compositions of continuation-passing matcher combinators. The\nmechanism for compiling the pattern language into combinators looks like an\nanalyzing Scheme interpreter. 
The design comes from GJS's 6.945 at MIT.\n\nThe pattern language compiler can be extended by registering new handlers at\nruntime, and new pattern compilers can be made by instantiating the\nPatternEvaluator class.\n\nThe grammar for the pattern language implemented in this file is:\n\n  pattern ::= element | choice | list | internal_node | negated_pattern | any\n  patterns ::= pattern, patterns | ()\n\n  element ::= ('?', element_name, restrictions)\n  element_name ::= PYTHON_STRING\n  restrictions ::= PYTHON_FUNCTION, restrictions | ()\n\n  choice ::= ('?:choice', patterns)\n\n  list ::= ('List', patterns)\n\n  internal_node ::= (pattern, neighbor_constraints)\n  neighbor_constraints ::= input_list | output_list | input_list, output_list\n  input_list ::= ('In', patterns)\n  output_list ::= ('Out', patterns)\n\n  negated_pattern ::= ('?:not', pattern)\n\n  any ::= ('?:any',)\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom tensorflow.python.util import tf_inspect\n\nfrom kfac.python.ops.tensormatch import tensorflow_graph_util as util\n\n\ndef _any(itr):\n  \"\"\"Similar to Python's any, but returns the first value that matches.\"\"\"\n  for val in itr:\n    if val:\n      return val\n  return False\n\n\ndef _all(itr):\n  \"\"\"Similar to Python's all, but returns the first value that doesn't match.\"\"\"\n  any_iterations = False\n  val = None\n  for val in itr:\n    any_iterations = True\n    if not val:\n      return val\n  return val if any_iterations else True\n\n\ndef is_seq(obj):\n  return isinstance(obj, (tuple, list))\n\n\ndef is_nonempty_seq(obj):\n  return is_seq(obj) and bool(obj)\n\n\ndef is_empty_seq(obj):\n  return is_seq(obj) and not bool(obj)\n\n\n## define the syntax of the pattern language\n\n\nis_pattern = is_nonempty_seq\n\n\ndef is_element_pattern(pat):\n  return is_pattern(pat) and pat[0] == '?'\n\n\ndef element_name(pat):\n  return pat[1]\n\n\ndef 
element_restrictions(pat):\n  return pat[2:]\n\n\ndef is_choice_pattern(pat):\n  return is_pattern(pat) and pat[0] == '?:choice'\n\n\ndef choice_patterns(pat):\n  return pat[1:]\n\n\ndef is_list_pattern(pat):\n  return is_pattern(pat) and pat[0] == 'List'\n\n\ndef list_patterns(pat):\n  return pat[1:]\n\n\ndef is_not_pattern(pat):\n  return is_pattern(pat) and pat[0] == '?:not'\n\n\ndef negated_pattern(pat):\n  return pat[1]\n\n\ndef is_any_pattern(pat):\n  return is_pattern(pat) and pat[0] == '?:any'\n\n\ndef is_any_noconsume_pattern(pat):\n  return is_pattern(pat) and pat[0] == '?:any_noconsume'\n\n\ndef is_internal_node_pattern(pat):\n  def is_neighbor_constraints(lst):\n    tags = tuple(item[0] for item in lst)\n    return tags in {('In',), ('Out',), ('In', 'Out')}\n  return (is_pattern(pat) and all(is_pattern(item) for item in pat)\n          and is_neighbor_constraints(pat[1:]))\n\n\ndef internal_node_pattern(pat):\n  return pat[0]\n\n\ndef internal_node_input_pattern(pat):\n  for item in pat[1:]:\n    if item[0] == 'In':\n      return ('List',) + tuple(item[1:])\n  return ('?:any_noconsume',)\n\n\ndef internal_node_output_pattern(pat):\n  for item in pat[1:]:\n    if item[0] == 'Out':\n      return ('List',) + tuple(item[1:])\n  return ('?:any_noconsume',)\n\n\ndef internal_patterns(pat):\n  return [internal_node_pattern(pat), internal_node_input_pattern(pat),\n          internal_node_output_pattern(pat)]\n\n\n## constructors for pattern-matching combinators\n\n\ndef match_eqv(pattern):\n  def eqv_match(data, bindings, consumed, succeed):\n    return data == pattern and succeed(bindings, consumed | {data})\n  return eqv_match\n\n\ndef match_any(data, bindings, consumed, succeed):\n  try:\n    consumed = consumed | {data}  # pylint: disable=g-no-augmented-assignment\n  except TypeError:\n    consumed = consumed | set(data)  # pylint: disable=g-no-augmented-assignment\n  return succeed(bindings, consumed)\n\n\ndef match_any_noconsume(data, bindings, consumed, 
succeed):  # pylint: disable=unused-argument\n  # this combinator succeeds (but does not append to the consumed set)\n  # regardless of the value of 'data', though the caller still passes 'data'\n  # (since all combinators have the same signature)\n  return succeed(bindings, consumed)\n\n\ndef match_element(variable_name, restrictions):\n  \"\"\"Matches an element.\"\"\"\n  def element_match(data, bindings, consumed, succeed):\n    consumed = consumed | {data}  # pylint: disable=g-no-augmented-assignment\n    if _all(restriction(data) for restriction in restrictions):\n      if not variable_name:\n        return succeed(bindings, consumed)\n      elif variable_name in bindings:\n        return bindings[variable_name] == data and succeed(bindings, consumed)\n      return succeed(dict(bindings, **{variable_name: data}), consumed)\n    return False\n  return element_match\n\n\ndef match_choice(*match_combinators):\n  def choice_match(data, bindings, consumed, succeed):\n    return _any(matcher(data, bindings, consumed, succeed)\n                for matcher in match_combinators)\n  return choice_match\n\n\ndef match_list(*match_combinators):\n  \"\"\"Matches a list.\"\"\"\n  def list_match(data, bindings, consumed, succeed):\n    return _list_match(data, match_combinators, bindings, consumed, succeed)\n\n  def _list_match(data, matchers, bindings, consumed, succeed):\n    \"\"\"Apply matchers elementwise to a list, collecting bindings sequentially.\n\n    Args:\n      data: The list on which to apply the matcher list.\n      matchers: The corresponding list of matchers to apply, element-by-element.\n      bindings: The dictionary of bindings to be consistent with.\n      consumed: The list of graph nodes consumed so far.\n      succeed: The continuation function to call when there is a match.\n\n    Returns:\n      False if there is no match, or succeed(bindings) if there is one.\n    \"\"\"\n    def match_first_then_subsequent(combinator, datum):\n      return 
combinator(datum, bindings, consumed, match_subsequent_elements)\n\n    def match_subsequent_elements(bindings, consumed):\n      return _list_match(data[1:], matchers[1:], bindings, consumed, succeed)\n\n    if is_empty_seq(matchers) and is_empty_seq(data):\n      return succeed(bindings, consumed)\n    return (is_nonempty_seq(matchers) and is_nonempty_seq(data)\n            and match_first_then_subsequent(matchers[0], data[0]))\n  return list_match\n\n\ndef match_not(match_combinator):\n  def not_match(data, bindings, consumed, succeed):\n    return (not match_combinator(data, bindings, set(),\n                                 lambda bindings, _: True)\n            and succeed(bindings, consumed))\n  return not_match\n\n\ndef match_internal(*match_combinators):\n  expanded_matcher = match_list(*match_combinators)\n  def internal_node_match(data, bindings, consumed, succeed):\n    try:\n      expanded = [data, util.expand_inputs(data), util.expand_outputs(data)]\n    except ValueError:\n      return False\n    return expanded_matcher(expanded, bindings, consumed, succeed)\n  return internal_node_match\n\n\n## parsing the pattern language into compositions of combinators\n\n\nclass PatternEvaluator(object):\n  \"\"\"Pattern evaluator class.\"\"\"\n\n  def __init__(self, default_operation=None):\n    self.default_operation = default_operation\n    self.handlers = []\n\n  def defhandler(self, predicate, handler):\n    self.handlers.append((predicate, handler))\n\n  def __call__(self, pat):\n    for predicate, handler in self.handlers:\n      if predicate(pat):\n        return handler(pat)\n    if self.default_operation:\n      return self.default_operation(pat)\n    raise ValueError\n\nmake_combinators = PatternEvaluator(match_eqv)\nmake_combinators.defhandler(\n    is_element_pattern,\n    lambda pat: match_element(element_name(pat), element_restrictions(pat)))\nmake_combinators.defhandler(\n    is_list_pattern,\n    lambda pat: match_list(*map(make_combinators, 
list_patterns(pat))))\nmake_combinators.defhandler(\n    is_choice_pattern,\n    lambda pat: match_choice(*map(make_combinators, choice_patterns(pat))))\nmake_combinators.defhandler(\n    is_not_pattern,\n    lambda pat: match_not(make_combinators(negated_pattern(pat))))\nmake_combinators.defhandler(\n    is_any_pattern,\n    lambda pat: match_any)\nmake_combinators.defhandler(\n    is_any_noconsume_pattern,\n    lambda pat: match_any_noconsume)\nmake_combinators.defhandler(\n    is_internal_node_pattern,\n    lambda pat: match_internal(*map(make_combinators, internal_patterns(pat))))\n\n\n## utility function so the patterns require fewer parentheses\n\n\ndef expand_thunks(pat):\n  \"\"\"Expands thunks (zero-argument functions) in a pattern by calling them.\n\n  Args:\n    pat: The pattern to expand, possibly containing thunks.\n\n  Returns:\n    The expanded pattern.\n  \"\"\"\n  def is_thunk(x):\n    if hasattr(x, '__call__'):\n      spec = tf_inspect.getargspec(x)\n      num_free_args = len(set(spec.args)) - len(set(spec.defaults or {}))\n      return num_free_args == 0\n    return False\n  while is_thunk(pat):\n    pat = pat()\n  if isinstance(pat, (tuple, list)):\n    return type(pat)(map(expand_thunks, pat))\n  return pat\n\n\n## main matcher interface functions\n\n\ndef matcher(pattern):\n  combinators = make_combinators(expand_thunks(pattern))\n  def match(node):\n    return combinators(node, {}, set(), lambda bindings, _: bindings or True)\n  return match\n\n\ndef all_matcher(pattern):\n  combinators = make_combinators(expand_thunks(pattern))\n  results = []\n\n  def all_matches(node):\n    combinators(node, {}, set(),\n                lambda bindings, _: results.append(bindings or True))\n    return results\n\n  return all_matches\n\n\ndef matcher_with_consumed(pattern):\n  combinators = make_combinators(expand_thunks(pattern))\n  def match(node):\n    return combinators(node, {}, set(),\n                       lambda bindings, consumed: (bindings, 
consumed))\n  return match\n"
  },
  {
    "path": "kfac/python/ops/tensormatch/graph_patterns.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Convenience functions for writing patterns in Python code..\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow.compat.v1 as tf\n\nfrom kfac.python.ops.tensormatch import tensorflow_graph_util as util\n\n\n## patterns\n\n\ndef Op(name=None):\n  return ('?', name, util.is_op)\n\n\ndef Tensor(name=None):\n  return ('?', name, util.is_tensor)\n\n\ndef Variable(name=None):\n  return ('?', name, util.is_var)\n\n\ndef Const(name=None):\n  return ('?', name, util.is_const)\n\n\ndef Placeholder(name=None):\n  return ('?', name, util.is_placeholder)\n\n\nutil.import_ops_no_clobber(globals(), dir(tf.raw_ops))\n\n# NOTE(mattjj): renamed in TF 1.0, but not registered as an op in 1.0.1\nUnstack = util.make_op_pattern('Unpack')  # pylint: disable=invalid-name\n\n## convenient compound patterns\n\n# The op definitions are pulled in via the op_def_registry, which is\n# why we disable the undefined variable check for e.g. 
Rsqrt, Mul, etc.\n# Otherwise we would have to refer to them by name rather than object.\n# pylint: disable=undefined-variable\n\n\ndef BatchNorm(in_pattern=Tensor('in'),\n              scale_name='scale',\n              offset_name='offset',\n              output_name='out'):\n  \"\"\"Pattern constructor for matching tf.nn.batch_normalization subgraphs.\"\"\"\n  inv_pat = (Tensor('inv'), ('In', ('?:choice', Rsqrt,\n                                    (Mul, ('In', (Tensor, ('In', Rsqrt)),\n                                           Tensor(scale_name))))))\n  without_offset_pat = (Mul, ('In', Tensor, Tensor('inv')))\n  with_offset_pat = (Sub, ('In', Tensor(offset_name),\n                           (Tensor, ('In', (Mul, ('In', Tensor,\n                                                  Tensor('inv')))))))\n  return (Tensor(output_name),\n          ('In', (AddV2, ('In', (Tensor, ('In', (Mul, ('In', in_pattern,\n                                                       inv_pat)))),\n                          (Tensor, ('In', ('?:choice', with_offset_pat,\n                                           without_offset_pat)))))))\n\n\ndef FusedBatchNormOutput(in_pattern=Tensor('in'),\n                         scale_name='scale',\n                         offset_name='offset',\n                         output_name='out'):\n  \"\"\"Pattern constructor for matching tf.nn.fused_batch_norm subgraphs.\"\"\"\n  return (Tensor(output_name),\n          ('In',\n           (('?:choice', FusedBatchNorm, FusedBatchNormV2, FusedBatchNormV3),\n            ('In', in_pattern, Tensor(scale_name), Tensor(offset_name), Tensor,\n             Tensor))))\n\n\n# TODO(mattjj): add more ops to this pattern\nNonlinearity = ('?:choice', Relu, Tanh)  # pylint: disable=invalid-name\n\n\ndef ScaleAndShift(in_pattern=Tensor('in'),\n                  scale_name='scale',\n                  shift_name='shift',\n                  output_name='out'):\n  \"\"\"Pattern constructor for matching scale & shift operation 
subgraphs.\"\"\"\n\n  scale_pat_r = (Mul, ('In', in_pattern, Variable(scale_name)))\n  scale_pat_l = (Mul, ('In', Variable(scale_name), in_pattern))\n\n  scale_pat = ('?:choice', scale_pat_r, scale_pat_l)\n\n  pat_r = (('?:choice', Add, AddV2),\n           ('In', (Tensor, ('In', scale_pat)), Variable(shift_name)))\n  pat_l = (('?:choice', Add, AddV2),\n           ('In', Variable(shift_name), (Tensor, ('In', scale_pat))))\n\n  return (Tensor(output_name), ('In', ('?:choice', pat_r, pat_l, scale_pat)))\n\n\ndef Affine(in_pattern=Tensor('in'),\n           linear_op_name='linear_op',\n           weights_name='weights',\n           biases_name='biases',\n           output_name='pre_activations'):\n  \"\"\"Pattern constructor for matching affine operation subgraphs.\"\"\"\n  linear_pat = (('?:choice', Conv2D(linear_op_name), MatMul(linear_op_name),\n                 BatchMatMulV2(linear_op_name)),\n                ('In', in_pattern, Variable(weights_name)))\n  affine_pat_r = (('?:choice', Add, BiasAdd, AddV2),\n                  ('In', (Tensor, ('In', linear_pat)), Variable(biases_name)))\n  affine_pat_l = (('?:choice', Add, BiasAdd, AddV2),\n                  ('In', Variable(biases_name), (Tensor, ('In', linear_pat))))\n  affine_pat = ('?:choice', affine_pat_r, affine_pat_l)\n  return (Tensor(output_name), ('In', ('?:choice', affine_pat, linear_pat)))\n\n\ndef Embed(in_pattern=Tensor('in'),\n          linear_op_name='linear_op',\n          weights_name='weights',\n          axis_name='axis',\n          output_name='pre_activations'):\n  \"\"\"Pattern constructor for matching embedding layer subgraphs.\"\"\"\n  embed_v1 = (('?:choice', Gather(linear_op_name),\n               ResourceGather(linear_op_name)),\n              ('In', Variable(weights_name), in_pattern))\n  embed_v2 = (GatherV2(linear_op_name),\n              ('In', Variable(weights_name), in_pattern, Tensor(axis_name)))\n  embed = ('?:choice', embed_v1, embed_v2)\n\n  return (Tensor(output_name), ('In', 
embed))\n\n\n# Only used in tests:\ndef Layer(in_pattern=Tensor('in'), **kwargs):\n  \"\"\"Pattern constructor for matching a basic layer.\"\"\"\n  return (Tensor('activations'), ('In', (Nonlinearity, ('In', Affine(\n      in_pattern, **kwargs)))))\n\n\n# Only used in tests:\ndef LayerWithBatchNorm(in_pattern=Tensor('in')):\n  \"\"\"Pattern constructor for matching a layer with batch normalization.\"\"\"\n  return (Tensor('final_activations'),\n          ('In', (Nonlinearity, ('In', BatchNorm(Affine(in_pattern))))))\n\n\n# pylint: enable=undefined-variable\n"
  },
  {
    "path": "kfac/python/ops/tensormatch/graph_search.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Functions for automatically registering network layers for K-FAC.\"\"\"\nimport collections\nfrom absl import logging\nimport enum\nimport tensorflow.compat.v1 as tf\n\n\nfrom tensorflow.python.framework import ops as tf_ops\nfrom tensorflow.python.ops import resource_variable_ops\nfrom kfac.python.ops import utils\nfrom kfac.python.ops.tensormatch import graph_matcher as gm\nfrom kfac.python.ops.tensormatch import graph_patterns as gp\nfrom kfac.python.ops.tensormatch import tensorflow_graph_util as graph_utils\n\n\nclass RecordType(enum.Enum):\n  fully_connected = 1\n  conv2d = 2\n  scale_and_shift = 3\n  batch_norm = 4\n\n\nclass AmbiguousRegistrationError(Exception):\n  pass\n\n\nclass MatchRecord(object):\n  \"\"\"An object for storing data about graph pattern matches.\"\"\"\n\n  def __init__(self, record_type, params, tensor_set, data=None):\n    \"\"\"Construct a new `Record` object.\n\n    Args:\n      record_type: A `RecordType` representing the type of layer being recorded.\n      params: A list of the variables used by this layer.\n      tensor_set: A set of all tensors matched by the pattern. 
This is used\n        for determining when one match is a subset of another.\n      data: An optional dict for storing attributes specific to certain\n        record types.\n    \"\"\"\n    self.record_type = record_type\n    self.params = params\n    self.tensor_set = tensor_set\n    if data is None:\n      data = dict()\n    self.data = data\n\n\ndef ensure_sequence(obj):\n  \"\"\"If `obj` isn't a tuple or list, return a tuple containing `obj`.\"\"\"\n  if isinstance(obj, (tuple, list)):\n    return obj\n  else:\n    return (obj,)\n\n\ndef record_affine_from_bindings(bindings, consumed_tensors,\n                                tensors_to_variables):\n  \"\"\"Construct a MatchRecord for the given Affine pattern bindings.\n\n  Args:\n    bindings: A dict representing a matched pattern. Strings representing\n      components of the pattern are mapped to the matched Tensors.\n    consumed_tensors: A set of all tensors consumed by the matched pattern.\n      This should be a superset of the values of the bindings dict.\n    tensors_to_variables: A dict mapping Tensors to the variables referencing\n      them.\n\n  Returns:\n    A `MatchRecord` containing the information necessary to register the layer.\n\n  Raises:\n    ValueError: If the bindings contain biases but not weights.\n  \"\"\"\n  if 'biases' in bindings:\n    biases = tensors_to_variables.get(bindings['biases'])\n  else:\n    biases = None\n  weights = tensors_to_variables.get(bindings['weights'], None)\n  inputs = bindings['in']\n  outputs = bindings['pre_activations']\n  linear_op = bindings['linear_op']\n\n  if biases is not None and weights is None:\n    raise ValueError(\"Can't register linear layer part with only biases.\")\n\n  if weights is not None and biases is not None:\n    params = (weights, biases)\n  else:\n    params = weights\n\n  if params is not None:\n    record_data = dict(inputs=inputs, outputs=outputs)\n\n    is_sparse = (linear_op.type == 'Gather'\n                 or linear_op.type 
== 'GatherV2'\n                 or linear_op.type == 'ResourceGather')\n\n    if (linear_op.type == 'MatMul' or linear_op.type == 'BatchMatMulV2' or\n        is_sparse):\n      record_type = RecordType.fully_connected\n\n      if len(inputs.shape) >= 4 or (is_sparse and len(inputs.shape) >= 3):\n        raise ValueError('K-FAC currently doesn''t support multi-use/temporal '\n                         'fully-connected layers with more than two batch/time '\n                         'dimensions. Two is the max, and they must be in the '\n                         'order [time, batch, ...]. Found this for params {} '\n                         'and op {}.'.format(params, repr(linear_op)))\n\n      if ((linear_op.type == 'MatMul'\n           and (linear_op.get_attr('transpose_a')\n                or linear_op.get_attr('transpose_b'))) or\n          (linear_op.type == 'BatchMatMulV2'\n           and (linear_op.get_attr('adj_x')\n                or linear_op.get_attr('adj_y')))):\n        raise ValueError('K-FAC currently doesn''t support fully-connected '\n                         'layers with transposed inputs or weights as part of '\n                         'the actual op. 
Found this for params {} and '\n                         'op {}.'.format(params, repr(linear_op)))\n\n      record_data['dense_inputs'] = not is_sparse\n\n    elif linear_op.type == 'Conv2D':\n      record_type = RecordType.conv2d\n      strides = tuple(map(int, linear_op.get_attr('strides')))\n      padding = linear_op.get_attr('padding')\n      data_format = linear_op.get_attr('data_format')\n      # In Python 3 this might be class \"bytes\" so we convert to string.\n      if not isinstance(padding, str):\n        padding = padding.decode()\n      if not isinstance(data_format, str):\n        data_format = data_format.decode()\n      record_data['strides'] = strides\n      record_data['padding'] = padding\n      record_data['data_format'] = data_format\n\n    else:\n      raise ValueError(\"Can't register operation: {}\".format(repr(linear_op)))\n\n    return MatchRecord(\n        record_type=record_type,\n        params=params,\n        tensor_set=consumed_tensors,\n        data=record_data)\n\n\ndef record_scale_and_shift_from_bindings(bindings, consumed_tensors,\n                                         tensors_to_variables):\n  \"\"\"Construct a MatchRecord for the given ScaleAndShift pattern bindings.\n\n  Args:\n    bindings: A dict representing a matched pattern. 
Strings representing\n      components of the pattern are mapped to the matched Tensors.\n    consumed_tensors: A set of all tensors consumed by the matched pattern.\n      This should be a superset of the values of the bindings dict.\n    tensors_to_variables: A dict mapping Tensors to the variables referencing\n      them.\n\n  Returns:\n    A `MatchRecord` containing the information necessary to register the layer.\n  \"\"\"\n  if 'shift' in bindings:\n    shift = tensors_to_variables.get(bindings['shift'])\n  else:\n    shift = None\n  scale = tensors_to_variables.get(bindings['scale'], None)\n\n  inputs = bindings['in']\n  outputs = bindings['out']\n\n  # I'm not sure if this can ever actually happen.\n  if shift is not None and scale is None:\n    raise ValueError(\"Can't register scale_and_shift with only shift.\")\n\n  if scale is not None and shift is not None:\n    params = (scale, shift)\n  else:\n    params = scale\n\n  if params is not None:\n    record_data = dict(inputs=inputs, outputs=outputs)\n\n    return MatchRecord(\n        record_type=RecordType.scale_and_shift,\n        params=params,\n        tensor_set=consumed_tensors,\n        data=record_data)\n\n\ndef record_batch_norm_from_bindings(bindings, consumed_tensors,\n                                    tensors_to_variables):\n  \"\"\"Construct a MatchRecord for the given BatchNorm pattern bindings.\n\n  Args:\n    bindings: A dict representing a matched pattern. 
Strings representing\n      components of the pattern are mapped to the matched Tensors.\n    consumed_tensors: A set of all tensors consumed by the matched pattern.\n      This should be a superset of the values of the bindings dict.\n    tensors_to_variables: A dict mapping Tensors to the variables referencing\n      them.\n\n  Returns:\n    A `MatchRecord` containing the information necessary to register the layer.\n  \"\"\"\n  if 'offset' in bindings:\n    offset = tensors_to_variables.get(bindings['offset'])\n  else:\n    offset = None\n\n  if 'scale' in bindings:\n    scale = tensors_to_variables.get(bindings['scale'])\n  else:\n    scale = None\n\n  inputs = bindings['in']\n  outputs = bindings['out']\n\n  if scale is not None and offset is not None:\n    params = (scale, offset)\n  elif scale is not None:\n    params = scale\n  elif offset is not None:\n    params = offset\n  else:\n    params = None\n\n  if params is not None:\n    record_data = dict(inputs=inputs, outputs=outputs)\n\n    return MatchRecord(\n        record_type=RecordType.batch_norm,\n        params=params,\n        tensor_set=consumed_tensors,\n        data=record_data)\n\n\ndef register_layers(layer_collection, varlist, batch_size=None):\n  \"\"\"Walk the graph and register all layers to layer_collection.\n\n  Parameters used multiple times in the graph need to be handled differently\n  depending on context: this could either mean the parameters represent an\n  RNN layer, or that the graph has been replicated as multiple \"towers\"\n  to allow data parallelism.\n  We differentiate these cases by examining the loss functions registered by\n  layer_collection: if losses have been registered multiple times with\n  reuse=True, we separate the subgraphs corresponding to each tower and\n  register layers independently for each with reuse=True.\n\n  Args:\n    layer_collection: A `LayerCollection` to use for registering layers.\n    varlist: A list of the variables in the graph.\n    
batch_size: A `int` representing the batch size. Needs to specified if\n      registering generic variables that don't match any layer patterns or\n      if time/uses is folded. If the time/uses dimension is merged with\n      batch then this is used to infer number of uses/time-steps. NOTE: In the\n      replicated context this must be the per-replica batch size, and not\n      the total batch size.\n\n  Returns:\n    A `dict` of the entries registered to layer_collection.fisher_blocks.\n\n  Raises:\n    ValueError: If not all losses were registered the same number of times.\n      If any variables specified as part of linked groups were not\n      matched with their group.\n      If the same variable is used in multiple layers types\n      (e.g. fully connected and 2d convolution), or if the same variable is\n      used in multiple layers of a type that doesn't support shared parameters.\n    AmbiguousRegistrationError: If any variables must be registered as generic\n      and batch_size is not specified, or if even after filtering, there are\n      matches with overlapping but unequal sets of variables (see\n      filter_records).\n  \"\"\"\n  original_fisher_blocks = layer_collection.fisher_blocks.copy()\n  user_registered_variables = set()\n  for params in layer_collection.fisher_blocks.keys():\n    for variable in ensure_sequence(params):\n      user_registered_variables.add(variable)\n  user_registered_variables = frozenset(user_registered_variables)\n\n  if not layer_collection.losses:\n    raise ValueError('No registered losses found. 
Automatic registration '\n                     'requires all losses in the graph to be registered before '\n                     'it can begin.')\n  else:\n    inputs_by_loss = tuple(tuple(loss.inputs for loss in loss_list)\n                           for loss_list in layer_collection.towers_by_loss)\n\n    num_towers = len(inputs_by_loss[0])\n\n    if not all(\n        (len(input_tensors) == num_towers for input_tensors in inputs_by_loss)):\n      raise ValueError(\n          'If losses are registered with reuse=True, each name must be '\n          'registered the same number of times.')\n\n    for tower_number, tower_input_tensors in enumerate(zip(*inputs_by_loss)):\n      reuse = (tower_number > 0)\n      with tf.variable_scope('tower_%d' % tower_number, reuse=reuse):\n        subgraph = utils.SubGraph(tower_input_tensors)\n        register_subgraph_layers(\n            layer_collection,\n            varlist,\n            subgraph,\n            user_registered_variables=user_registered_variables,\n            reuse=reuse,\n            batch_size=batch_size)\n\n  fisher_blocks = layer_collection.fisher_blocks\n  return {\n      params: fisher_blocks[params]\n      for params in set(fisher_blocks) - set(original_fisher_blocks)\n  }\n\n\ndef register_subgraph_layers(layer_collection,\n                             varlist,\n                             subgraph,\n                             user_registered_variables=frozenset(),\n                             reuse=False,\n                             batch_size=None):\n  \"\"\"Walk a subgraph and register all layers to layer_collection.\n\n  Args:\n    layer_collection: A `LayerCollection` to use for registering layers.\n    varlist: A list of the variables in the graph.\n    subgraph: The `SubGraph` to search.\n    user_registered_variables: A set of all the variables the user has manually\n      registered. No layers using any of these variables should be registered.\n    reuse: (OPTIONAL) bool. 
If True, then `layer_collection`\n      selects a previously registered block with the same key as the key\n      derived from `params` of that block. If False, a new block is\n      registered.\n    batch_size: A `int` representing the batch size. Needs to specified if\n      registering generic variables that don't match any layer patterns or\n      if the time/uses dimension is folded into batch. If the time/uses\n      dimension is merged with batch then this is used to infer number of\n      uses/time-steps.\n\n  Raises:\n    ValueError: If any variables specified as part of linked groups were not\n      matched with their group.\n      If the same variable is used in multiple layers types\n      (e.g. fully connected and 2d convolution), or if the same variable is\n      used in multiple layers of a type that doesn't support shared parameters.\n    AmbiguousRegistrationError: If any variables must be registered as generic\n      and batch_size is not specified, or if even after filtering, there are\n      matches with overlapping but unequal sets of variables (see\n      filter_records).\n  \"\"\"\n\n  # List of patterns and binding functions to use when we match one of them\n  match_register_list = [(gm.matcher_with_consumed(gp.Affine),\n                          record_affine_from_bindings),\n                         (gm.matcher_with_consumed(gp.ScaleAndShift),\n                          record_scale_and_shift_from_bindings),\n                         (gm.matcher_with_consumed(gp.BatchNorm),\n                          record_batch_norm_from_bindings),\n                         (gm.matcher_with_consumed(gp.FusedBatchNormOutput),\n                          record_batch_norm_from_bindings),\n                         (gm.matcher_with_consumed(gp.Embed),\n                          record_affine_from_bindings)]\n\n  # Patterns return bindings to raw tensors, so we need to be able to map back\n  # to variables from the tensors those variables reference.\n  def 
var_to_tensors(var):\n    if resource_variable_ops.is_resource_variable(var):\n      if tf.control_flow_v2_enabled() and hasattr(layer_collection.graph,\n                                                  'captures'):\n        # TODO(b/143690035): Note that the \"captures\" property relies on an\n        # API which might change.\n        captures = layer_collection.graph.captures\n        return [h for vh, h in captures if vh is var.handle]\n      else:\n        return [var.handle]\n    if utils.is_reference_variable(var):\n      return [tf_ops.internal_convert_to_tensor(var, as_ref=True)]\n    raise ValueError('%s is not a recognized variable type.' % str(var))\n\n  tensors_to_variables = {tensor: var for var in varlist\n                          for tensor in var_to_tensors(var)}\n\n  # Get all the ops from the graph.\n  ops = layer_collection.graph.get_operations()\n\n  # Filter out tf.identity ops since otherwise the matcher generates spurious\n  # matches.\n  ops = tuple(op for op in ops if not graph_utils.is_identity(op))\n\n  # Extract out the output tensors from the ops\n  tensors = tuple(out for op in ops for out in op.outputs)\n\n  # Filter the tensors to include only those in the subgraph.\n  tensors = subgraph.filter_list(tensors)\n\n  # Go through each tensor and try to match each pattern to it.\n  record_list_dict = dict()\n  for tensor in tensors:\n    for match, recfunc in match_register_list:\n      match_res = match(tensor)\n      if match_res:\n        bindings, consumed_tensors = match_res\n        record = recfunc(bindings, consumed_tensors, tensors_to_variables)\n        if record is not None:\n          if record.params not in record_list_dict:\n            record_list_dict[record.params] = []\n          record_list_dict[record.params].append(record)\n\n  # Filter out records violating any rules.\n  record_list_dict = filter_records(layer_collection, record_list_dict,\n                                    user_registered_variables)\n\n  # 
Register the layers by going through the lists of records for each param.\n  register_records(layer_collection, record_list_dict, reuse, batch_size)\n\n  # Determine which variables were registered either by the user or\n  # in the current call to register_subgraph_layers.\n  automatically_registered_variables = {\n      var\n      for params in record_list_dict\n      for var in ensure_sequence(params)\n  }\n  registered_variables = (\n      automatically_registered_variables | user_registered_variables)\n\n  # Register any remaining parameters generically.\n  for variable in varlist:\n    if variable not in registered_variables:\n      for specified_grouping in layer_collection.linked_parameters:\n        assert isinstance(specified_grouping, frozenset)\n        if variable in specified_grouping and len(specified_grouping) > 1:\n          raise ValueError(\n              'Variable {} in linked group {} was not matched.'.format(\n                  variable, specified_grouping))\n\n      generic_bad_string = ('generic registrations may be a symptom that the '\n                            'scanner is failing to auto-detect your model. '\n                            'Generic uses a last-resort approximation, and '\n                            'should never be used for common layer types that '\n                            'K-FAC properly supports, such as convs or '\n                            'fully-connected layers.')\n      if batch_size is None:\n        raise AmbiguousRegistrationError(\n            ('Tried to register {} as generic without knowledge of batch_size. '\n             'You can pass batch_size in to fix this error. But please note, '\n             + generic_bad_string).format(variable))\n      logging.warning(('Registering {} as generic because graph scanner '\n                       'couldn\\'t match a pattern for it. 
This can sometimes '\n                       'be caused by the variable not being present in the '\n                       'graph terminating at the registered losses. You might '\n                       'need to pass an explicit list of parameters to tell '\n                       'the system what parameters are actually in your model. '\n                       'Note that ' + generic_bad_string).format(variable))\n      layer_collection.register_generic(variable, batch_size, reuse=reuse)\n\n\ndef filter_user_registered_records(record_list_dict, user_registered_variables):\n  \"\"\"Remove any matches that contain a variable registered by the user.\"\"\"\n  record_list_dict = record_list_dict.copy()\n  for params in list(record_list_dict.keys()):\n    for variable in ensure_sequence(params):\n      if variable in user_registered_variables:\n        del record_list_dict[params]\n        break\n  return record_list_dict\n\n\ndef filter_grouped_variable_records(layer_collection, record_list_dict):\n  \"\"\"Remove any matches violating user specified parameter groupings.\"\"\"\n  record_list_dict = record_list_dict.copy()\n  for params in list(record_list_dict.keys()):\n    for specified_grouping in layer_collection.linked_parameters:\n      param_set = set(ensure_sequence(params))\n      assert isinstance(specified_grouping, frozenset)\n      if (param_set.intersection(specified_grouping) and\n          param_set != specified_grouping):\n        del record_list_dict[params]\n        break\n  return record_list_dict\n\n\ndef filter_subgraph_records(record_list_dict):\n  \"\"\"Remove any matches that correspond to strict subgraphs of other matches.\"\"\"\n\n  # Flatten the records dict to compare records with different parameters.\n  flat_record_list = [\n      record for records in record_list_dict.values() for record in records\n  ]\n\n  # Compare all pairs of records that share any variables. 
We perform two\n  # passes, first marking variables for deletion by adding them to a set and\n  # then removing all marked variables, in order to avoid traversing\n  # flat_record_list on every removal while still maintaining record order.\n  records_by_variable = collections.defaultdict(list)\n  for record in flat_record_list:\n    for variable in ensure_sequence(record.params):\n      records_by_variable[variable].append(record)\n  records_to_remove = set()\n  for record in flat_record_list:\n    for variable in ensure_sequence(record.params):\n      for other_record in records_by_variable[variable]:\n        if record.tensor_set < other_record.tensor_set:\n          records_to_remove.add(record)\n  flat_record_list = [\n      record for record in flat_record_list if record not in records_to_remove\n  ]\n\n  # Unflatten the records list.\n  record_list_dict = collections.defaultdict(list)\n  for record in flat_record_list:\n    record_list_dict[record.params].append(record)\n    assert record is not None\n  return dict(record_list_dict)\n\n\ndef filter_records(layer_collection, record_list_dict,\n                   user_registered_variables):\n  \"\"\"Filter out recorded matches based on a set of rules.\n\n  A match should be filtered out if any of the following are true:\n    1. It contains any variables already registered by the user.\n    2. It violates the user specified variable groupings.\n    3. It corresponds to a strict subgraph of another match not already filtered\n       out by the above steps.\n\n  Args:\n    layer_collection: A `LayerCollection` to use for registering layers.\n    record_list_dict: A dict mapping tuples of variables to lists of\n      `MatchRecord`s representing all of the places those variables are used\n      in the graph.\n    user_registered_variables: A set of all the variables the user has manually\n      registered. 
No layers using any of these variables should be registered.\n\n  Returns:\n    A copy of `record_list_dict` with the records violating rules filtered out.\n\n  Raises:\n    AmbiguousRegistrationError: If even after filtering, there are matches\n      with overlapping but unequal sets of variables. In these cases, the user\n      will need to either manually register layers that use these variables,\n      or specify a preferred variable grouping.\n  \"\"\"\n  record_list_dict = filter_user_registered_records(record_list_dict,\n                                                    user_registered_variables)\n  record_list_dict = filter_grouped_variable_records(layer_collection,\n                                                     record_list_dict)\n  record_list_dict = filter_subgraph_records(record_list_dict)\n\n  # Look for any violation in the consistency of the remaining matches.\n  recorded_params = dict()\n  ambiguous_registration_errors = []\n  for params in record_list_dict:\n    for variable in ensure_sequence(params):\n      if variable in recorded_params:\n        ambiguous_registration_errors.append(\n            'Variable {} was recorded in multiple groups: {} and {}.'.format(\n                variable, params, recorded_params[variable]))\n      else:\n        recorded_params[variable] = params\n  if ambiguous_registration_errors:\n    raise AmbiguousRegistrationError('\\n'.join(ambiguous_registration_errors))\n\n  return record_list_dict\n\n\ndef register_records(layer_collection,\n                     record_list_dict,\n                     reuse=False,\n                     batch_size=None):\n  \"\"\"Registers the given records to layer_collection.\n\n  Args:\n    layer_collection: A `LayerCollection` to use for registering layers.\n    record_list_dict: A dict mapping tuples of variables to lists of\n      `MatchRecord`s representing all of the places those variables are used\n      in the graph.\n    reuse: (OPTIONAL) bool. 
If True, then `layer_collection`\n      selects a previously registered block with the same key as the key\n      derived from `params` of that block. If False, a new block is\n      registered.\n    batch_size: A `int` representing the batch size. Needs to specified if\n      registering generic variables that don't match any layer patterns or\n      if time/uses is folded. If the time/uses dimension is merged with\n      batch then this is used to infer number of uses/time-steps.\n\n  Raises:\n    ValueError: If record_list_dict contains multiple record types for a single\n      set of variables, or if there are multiple records for a set of variables\n      of a type that doesn't support shared parameters.\n    AmbiguousRegistrationError: If a batch norm layer registration is required\n      but batch_size is not passed.\n  \"\"\"\n\n  mixed_record_type_errors = []\n\n  # TODO(b/69627702): Layers must be registered in a deterministic order, else\n  # FisherFactors may end up with different variable names.\n  params_list = sorted(record_list_dict.keys(), key=str)\n  for params in params_list:\n    record_list = record_list_dict[params]\n    # We don't support mixed types for the same params and probably never\n    # will.\n    if not all(record_list[0].record_type == record.record_type\n               for record in record_list):\n      mixed_record_type_errors.append(\n          'Detected variables {} with mixed record types: {}.'.format(\n              params, record_list))\n      continue\n\n    record_type = record_list[0].record_type\n\n    if record_type is RecordType.fully_connected:\n\n      dense_inputs = record_list[0].data['dense_inputs']\n      if (not dense_inputs\n          and layer_collection._get_linked_approx(params) is None):  # pylint: disable=protected-access\n        # Nothing is lost by using a diagonal approx for the input factor here.\n        # This is because the 2nd-moment matrix for 1-hot vectors will be\n        # naturally 
diagonal.\n        approx = 'kron_indep_in_diag'\n      else:\n        approx = None\n\n      if len(record_list) > 1:\n        logging.info(\n            'Registering as multi-use fully-connected: {}'.format(params))\n\n        inputs = tuple(record.data['inputs'] for record in record_list)\n        outputs = tuple(record.data['outputs'] for record in record_list)\n        layer_collection.register_fully_connected_multi(\n            params, inputs, outputs, reuse=reuse, dense_inputs=dense_inputs,\n            approx=approx)\n      else:\n        if dense_inputs:\n          folded_dim_limit = 2\n        else:\n          folded_dim_limit = 1\n\n        record = record_list[0]\n        inputs = record.data['inputs']\n        outputs = record.data['outputs']\n\n        first_dim = inputs.shape.as_list()[0]\n        num_dim = len(inputs.shape)\n\n        is_batch_time_folded = not (\n            batch_size is None\n            or first_dim is None\n            or first_dim == batch_size\n            or num_dim > folded_dim_limit)\n\n        if is_batch_time_folded or num_dim > folded_dim_limit:\n          logging.info(\n              'Registering as multi-use fully-connected: {}'.format(params))\n\n          logging.warning('Registering {} as multi-use fully-connected layer '\n                          'with folded batch and time/use dimension. If using '\n                          'the non-independent K-FAC RNNs approximations ('\n                          '\"Option 1\" or \"Option 2\") make sure that the '\n                          'dimensions are ordered [time/use, batch] before '\n                          'folding, and not the other way around. 
Otherwise '\n                          'you will get a silent failure of the method!'\n                          ''.format(params))\n\n          if is_batch_time_folded:\n            if first_dim % batch_size != 0:\n              raise ValueError('Passed batch_size did not divide first '\n                               'dimension of tensor with presumed folded '\n                               'batch and use/times dimension. Possible causes '\n                               'include passing the wrong batch size (e.g. '\n                               'passing overall instead of per-replica), or a '\n                               'non-standard layer (possibly with no batch '\n                               'dependency). Layer params are: '\n                               '{}. Input and output tensors are: {} and {}'\n                               ''.format(params, inputs, outputs))\n            num_uses = first_dim // batch_size\n          else:\n            num_uses = record_list[0].data['inputs'].shape.as_list()[1]\n\n          layer_collection.register_fully_connected_multi(\n              params, inputs, outputs, num_uses=num_uses, reuse=reuse,\n              dense_inputs=dense_inputs, approx=approx)\n        else:\n          logging.info('Registering as fully-connected: {}'.format(params))\n          layer_collection.register_fully_connected(\n              params, inputs, outputs, reuse=reuse, dense_inputs=dense_inputs,\n              approx=approx)\n\n    elif record_type is RecordType.conv2d:\n      if len(record_list) > 1:\n        logging.info('Registering as multi-use conv2d: {}'.format(params))\n\n        inputs = tuple(record.data['inputs'] for record in record_list)\n        outputs = tuple(record.data['outputs'] for record in record_list)\n        strides = record_list[0].data['strides']\n        padding = record_list[0].data['padding']\n        data_format = record_list[0].data['data_format']\n        layer_collection.register_conv2d_multi(\n       
     params,\n            strides,\n            padding,\n            inputs,\n            outputs,\n            data_format=data_format,\n            reuse=reuse)\n      else:\n        record = record_list[0]\n        inputs = record.data['inputs']\n        outputs = record.data['outputs']\n        strides = record.data['strides']\n        padding = record.data['padding']\n        data_format = record.data['data_format']\n\n        first_dim = inputs.shape.as_list()[0]\n        num_dim = len(inputs.shape)\n\n        is_batch_time_folded = not (\n            batch_size is None\n            or first_dim is None\n            or first_dim == batch_size\n            or num_dim > 4)\n\n        if is_batch_time_folded or num_dim > 4:\n          logging.info('Registering as multi-use conv2d: {}'.format(params))\n\n          if is_batch_time_folded:\n            if first_dim % batch_size != 0:\n              raise ValueError('Passed batch_size did not divide first '\n                               'dimension of tensor with presumed folded '\n                               'batch and use/times dimension. Possible causes '\n                               'include passing the wrong batch size (e.g. '\n                               'passing overall instead of per-replica), or a '\n                               'non-standard layer (possibly with no batch '\n                               'dependency). Layer params are: '\n                               '{}. 
Input and output tensors are: {} and {}'\n                               ''.format(params, inputs, outputs))\n            num_uses = first_dim // batch_size\n          else:\n            raise ValueError('Currently not supporting conv layers with '\n                             'separate time/uses dim.')\n\n          layer_collection.register_conv2d_multi(\n              params,\n              strides,\n              padding,\n              inputs,\n              outputs,\n              data_format=data_format,\n              num_uses=num_uses,\n              reuse=reuse)\n        else:\n          logging.info('Registering as conv2d: {}'.format(params))\n          layer_collection.register_conv2d(params, strides, padding, inputs,\n                                           outputs, data_format=data_format,\n                                           reuse=reuse)\n\n    elif record_type is RecordType.scale_and_shift:\n      logging.info('Registering as scale (& shift): {}'.format(params))\n\n      if len(record_list) > 1:\n        raise ValueError('Multi-use registrations currently not supported for '\n                         'scale & shift operations.')\n      record = record_list[0]\n      inputs = record.data['inputs']\n      outputs = record.data['outputs']\n\n      layer_collection.register_scale_and_shift(params, inputs, outputs,\n                                                reuse=reuse)\n\n    elif record_type is RecordType.batch_norm:\n      # For now we register this as generic instead of scale_and_shift because\n      # the fused version of batch norm won't give us the quantities we need\n      # for the latter. Could consider splitting this into fused and non-fused\n      # cases.\n\n      logging.info('Registering as generic (batch norm): {}'.format(params))\n\n      if batch_size is None:\n        raise AmbiguousRegistrationError(\n            'Tried to register a batch norm layer (as generic) without '\n            'knowledge of batch_size. 
You can pass batch_size in to fix this '\n            'error.')\n\n      # This is a slight hack. Ideally register_generic would work with lists\n      # of params like it used to before we switched to the \"unflattened\" cov\n      # representation so we wouldn't need to detect the approximation type.\n      will_use_diag = (\n          layer_collection._get_linked_approx(params) == 'diagonal'  # pylint: disable=protected-access\n          or (layer_collection.default_generic_approximation == 'diagonal'\n              and layer_collection._get_linked_approx(params) is None)  # pylint: disable=protected-access\n          )\n      if will_use_diag:\n        for param in ensure_sequence(params):\n          layer_collection.register_generic(param, batch_size, reuse=reuse)\n      else:\n        layer_collection.register_generic(params, batch_size, reuse=reuse)\n\n    else:\n      assert False, 'Invalid record type {}'.format(record_type)\n\n  if mixed_record_type_errors:\n    raise ValueError('\\n'.join(mixed_record_type_errors))\n"
  },
  {
    "path": "kfac/python/ops/tensormatch/tensorflow_graph_util.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Abstraction layer for working with the TensorFlow graph model.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# Dependency imports\nimport six\nimport tensorflow.compat.v1 as tf\n\nfrom tensorflow.python.ops import resource_variable_ops\nfrom kfac.python.ops import utils\n# pylint: disable=g-import-not-at-top\ntry:\n  from tensorflow.python.types import core\nexcept ModuleNotFoundError:\n  from tensorflow.python.framework import ops as tf_ops\n# pylint: enable=g-import-not-at-top\n\n\ndef is_op(node):\n  return isinstance(node, tf.Operation)\n\n\ndef is_tensor(node):\n  try:\n    # TODO(b/154650521): Use tf.Tensor instead of core.Tensor.\n    return isinstance(node, core.Tensor)\n  except NameError:\n    return tf_ops.is_dense_tensor_like(node)\n\n\ndef is_var(node):\n  if not is_tensor(node):\n    return False\n  if node.op.type.startswith('Variable'):\n    return True\n  if ((resource_variable_ops.is_resource_variable(node) or\n       utils.is_reference_variable(node))):\n    return True\n  # TODO(b/143690035): Note that the Placeholder type handles the Control Flow\n  # V2 case, but this could stop working in the future if the implementation of\n  # Control Flow V2 
changes.\n  if node.dtype == tf.resource and (node.op.type == 'VarHandleOp'\n                                    or node.op.type == 'Placeholder'):\n    return True\n  return False\n\n\ndef is_const(node):\n  return is_tensor(node) and node.op.type == 'Const'\n\n\ndef is_placeholder(node):\n  return is_tensor(node) and node.op.type == 'Placeholder'\n\n\ndef is_leaf(node):\n  return is_var(node) or is_const(node) or is_placeholder(node)\n\n\ndef is_identity(node):\n  if not is_op(node):\n    return False\n  # For ResourceVariables, a 'ReadVariableOp' has a single 'Enter' input, which\n  # in turn has a Tensor with dtype == resource as input.\n  return (node.type in {'Identity', 'ReadVariableOp', 'Enter', 'IdentityN'}\n          or 'convert_gradient_to_tensor' in node.type)\n\n\ndef op_type_is(typename):\n\n  def is_op_with_typename(node):\n    return is_op(node) and node.type == typename\n\n  return is_op_with_typename\n\n\ndef reduce_identity_ops(node):\n  while is_tensor(node) and is_identity(node.op):\n    # IdentityN is sometimes used when custom gradients are involved. Its\n    # two inputs should be the same in that case. 
Otherwise there should only\n    # be one input.\n    assert (len(node.op.inputs) == 1\n            or (node.op.type == 'IdentityN'\n                and node.op.inputs[0] == node.op.inputs[1]))\n\n    node = node.op.inputs[0]\n  return node\n\n\ndef expand_inputs(node):\n  \"\"\"Return a list of input nodes for a given TF graph node (or node list).\"\"\"\n  if is_op(node):\n    return [reduce_identity_ops(tensor) for tensor in node.inputs[:]]\n  elif is_tensor(node) and not is_leaf(node):\n    return [reduce_identity_ops(node).op]\n  elif isinstance(node, list) and all(is_tensor(elt) for elt in node):\n    ops = {reduce_identity_ops(tensor).op for tensor in node}\n    if len(ops) == 1:\n      return [ops.pop()]\n    raise ValueError\n  return None\n\n\ndef expand_outputs(node):\n  \"\"\"Return a list of output nodes for a given TF graph node.\"\"\"\n  if is_op(node):\n    return node.outputs[:]\n  elif isinstance(node, tf.Variable):\n    return node.value().consumers()\n  elif is_tensor(node):\n    return node.consumers()\n  return None\n\n\ndef make_op_pattern(typename):\n  \"\"\"Makes a pattern that matches a given Op type.\"\"\"\n\n  def op_fun(name=None):\n    return ('?', name, op_type_is(typename))\n\n  op_fun_name = typename.encode('ascii', 'ignore')\n\n  # In Python 3, str.encode() produces a bytes object. Convert this to an ASCII\n  # str.\n  if six.PY3:\n    op_fun_name = op_fun_name.decode('ascii')\n\n  op_fun.__name__ = op_fun_name\n  return op_fun\n\n\ndef import_ops_no_clobber(dct, op_names):\n  for name in op_names:\n    if name not in dct:\n      dct[name] = make_op_pattern(name)\n"
  },
  {
    "path": "kfac/python/ops/utils.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utility functions.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\n# Dependency imports\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\nfrom tensorflow.python.tpu import tpu_function\nfrom tensorflow.python.tpu.ops import tpu_ops\nfrom tensorflow.python.ops import resource_variable_ops\nfrom tensorflow.python.util import nest\n\n# Method used for inverting matrices.\nPOSDEF_INV_METHOD = \"cholesky\"\nPOSDEF_EIG_METHOD = \"self_adjoint\"\n\n_TF_REPLICATOR = None\n\n\ndef smart_assign(variable, value, assign_fn=tf.assign,\n                 force_cast=False, force_sync=True):\n  \"\"\"Calls assign_fn on variable and value in a cross-replica context.\n\n  When this function is called in a per-replica context, it will enter a cross-\n  replica context before calling assign_fn(variable, value). During training\n  with a tf.distribute.Strategy, optimizer.minimize is always called in a per-\n  replica context (e.g. via experimental_run for TPUStrategy). 
Since with this\n  function we assign a synchronized Tensor to a MirroredVariable with assign_fn,\n  we use a merge_call to enter a cross-replica context, then use\n  distribution.extended.update to assign value to variable with assign_fn.\n\n  When this function is called in a cross-replica context or outside of a\n  tf.distribute.Strategy scope, smart_assign will use assign_fn as is.\n  Operations that happen inside of a tf.distribute.Strategy scope\n  are typically in a cross replica context, unless, for example, they happen in\n  an experimental_run call or a call_for_each_replica call.  In a cross-replica\n  context, tf.distribute.get_replica_context() returns None.\n\n  Args:\n    variable: TF Variable. A MirroredVariable when in a distribution strategy.\n    value: TF Tensor. This function will throw an error if value is a PerReplica\n      type, which means it is an unsynchronized Tensor. You must reduce it using\n      all_sum or all_average before using this function.\n    assign_fn: assign_fn(variable, value) -> tf.Operation. The function\n      used to update variable with value, typically tf.assign, tf.assign_add,\n      or tf.assign_sub.\n    force_cast: Boolean. If True we cast the `value` to the dtype of `variable`\n      when they don't match. (Default: False)\n    force_sync: Boolean. 
If True and using MirroredStrategy in a replica\n      context, take the mean of value over all replicas to force the value to be\n      syncronized before performing the assignment.\n\n  Returns:\n    tf.Tensor that contains the result of assign_fn(variable, value) called in\n    a cross-replica context.\n  \"\"\"\n  if force_cast and variable.dtype != value.dtype:\n    value = tf.cast(value, dtype=variable.dtype)\n\n  if not (tf.distribute.has_strategy() and tf.distribute.get_replica_context()):\n    return assign_fn(variable, value)\n\n  def merge_fn(distribution, variable, value):\n    strategy = tf.distribute.get_strategy()\n    if isinstance(strategy, tf.distribute.MirroredStrategy) and force_sync:\n      value = strategy.reduce(tf.distribute.ReduceOp.MEAN, value)\n    return distribution.extended.update(variable, assign_fn, args=(value,))\n\n  return tf.distribute.get_replica_context().merge_call(\n      merge_fn, args=(variable, value))\n\n\ndef smart_cond(predicate, true_fn, false_fn, name=None):\n  \"\"\"Creates ops for conditionally executing one of two functions.\n\n  If MirroredStrategy is not used or outside of a MirroredStrategy replica\n  context, this is identical to tf.cond.\n  tf.cond does not support using functions which involve synchronization calls\n  inside a MirroredStrategy replica context. Instead, work around this by safely\n  evaluating the conditional across replicas and then evaluate either true_fn or\n  false_fn back in a replica context.\n\n  Note: this is only required if true_fn and/or false_fn involve a\n  synchronization across replicas (e.g. via a reduction to evaluate the\n  cross-replica mean).\n\n  Limitations: with MirroredStrategy, true_fn and false_fn are executed via\n  control_dependencies are a constant tensor is returned instead of the actual\n  return values of true_fn and false_fn. 
This is due to the requirement that\n  functions executed using DistributionStrategy.call_for_each_replica return a\n  tensor rather than an operation.\n\n  Args:\n    predicate: boolean operation which determines whether to execute true_fn or\n      false_fn.\n    true_fn: function to execute if predicate is true.\n    false_fn: function to execute if predicate is false.\n    name: name to assign to the tf.cond operation.\n\n  Returns:\n    If not using MirroredStrategy or outside of a MirroredStrategy replica\n    context, the result from true_fn or false_fn, and otherwise a constant\n    tensor.\n  \"\"\"\n  if (tf.distribute.has_strategy() and tf.distribute.get_replica_context()):\n    strategy = tf.distribute.get_strategy()\n  else:\n    strategy = None\n  if not isinstance(strategy, tf.distribute.MirroredStrategy):\n    return tf.cond(predicate, true_fn, false_fn, name)\n  else:\n    # Conditionals with functions which execute synchronization calls are not\n    # well supported with Distribution Strategy. Instead follow the scheme\n    # suggested in https://github.com/tensorflow/tensorflow/issues/27716:\n    # 1. Execute the conditional in a cross-replica context.\n    # 2. The conditional functions then return to a replica-context before\n    # executing the original conditional functions.\n    def true_fn_per_replica():\n      # call_for_each_replica requires a tensor to be returned. 
This is not true\n      # for all functions (which, e.g., might return an op or tf.group) so\n      # instead execute the ops as control dependency and return a constant\n      # tensor.\n      with tf.control_dependencies([true_fn()]):\n        return tf.constant(0.0)\n    def true_fn_cross_replica():\n      strategy = tf.distribute.get_strategy()\n      return strategy.extended.call_for_each_replica(true_fn_per_replica)\n    def false_fn_per_replica():\n      with tf.control_dependencies([false_fn()]):\n        return tf.constant(0.0)\n    def false_fn_cross_replica():\n      strategy = tf.distribute.get_strategy()\n      return strategy.extended.call_for_each_replica(false_fn_per_replica)\n    def cond(distribution):\n      del distribution\n      return tf.cond(predicate, true_fn_cross_replica, false_fn_cross_replica, name)\n    return tf.distribute.get_replica_context().merge_call(cond)\n\n\ndef set_global_constants(posdef_inv_method=None, tf_replicator=None):\n  \"\"\"Sets various global constants used by the classes in this module.\"\"\"\n  global POSDEF_INV_METHOD\n  global _TF_REPLICATOR\n\n  if posdef_inv_method is not None:\n    POSDEF_INV_METHOD = posdef_inv_method\n  if tf_replicator is not None:\n    _TF_REPLICATOR = tf_replicator\n\n\nclass SequenceDict(object):\n  \"\"\"A dict convenience wrapper that allows getting/setting with sequences.\"\"\"\n\n  def __init__(self, iterable=None):\n    self._dict = dict(iterable or [])\n\n  def __getitem__(self, key_or_keys):\n    if isinstance(key_or_keys, (tuple, list)):\n      return list(map(self.__getitem__, key_or_keys))\n    else:\n      return self._dict[key_or_keys]\n\n  def __setitem__(self, key_or_keys, val_or_vals):\n    if isinstance(key_or_keys, (tuple, list)):\n      for key, value in zip(key_or_keys, val_or_vals):\n        self[key] = value\n    else:\n      self._dict[key_or_keys] = val_or_vals\n\n  def items(self):\n    return list(self._dict.items())\n\n\ndef tensors_to_column(tensors):\n  
\"\"\"Converts a tensor or list of tensors to a column vector.\n\n  Args:\n    tensors: A tensor or list of tensors.\n\n  Returns:\n    The tensors reshaped into vectors and stacked on top of each other.\n  \"\"\"\n  if isinstance(tensors, (tuple, list)):\n    return tf.concat(\n        tuple(tf.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0)\n  else:\n    return tf.reshape(tensors, [-1, 1])\n\n\ndef column_to_tensors(tensors_template, colvec):\n  \"\"\"Converts a column vector back to the shape of the given template.\n\n  Args:\n    tensors_template: A tensor or list of tensors.\n    colvec: A 2d column vector with the same shape as the value of\n        tensors_to_column(tensors_template).\n\n  Returns:\n    X, where X is tensor or list of tensors with the properties:\n     1) tensors_to_column(X) = colvec\n     2) X (or its elements) have the same shape as tensors_template (or its\n        elements)\n  \"\"\"\n  if isinstance(tensors_template, (tuple, list)):\n    offset = 0\n    tensors = []\n    for tensor_template in tensors_template:\n      sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32)\n      tensor = tf.reshape(colvec[offset:(offset + sz)], tensor_template.shape)\n      tensors.append(tensor)\n      offset += sz\n\n    tensors = tuple(tensors)\n  else:\n    tensors = tf.reshape(colvec, tensors_template.shape)\n\n  return tensors\n\n\ndef kronecker_product(mat1, mat2):\n  \"\"\"Computes the Kronecker product two matrices.\"\"\"\n  m1, n1 = mat1.get_shape().as_list()\n  mat1_rsh = tf.reshape(mat1, [m1, 1, n1, 1])\n  m2, n2 = mat2.get_shape().as_list()\n  mat2_rsh = tf.reshape(mat2, [1, m2, 1, n2])\n  return tf.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])\n\n\ndef layer_params_to_mat2d(vector):\n  \"\"\"Converts a vector shaped like layer parameters to a 2D matrix.\n\n  In particular, we reshape the weights/filter component of the vector to be\n  2D, flattening all leading (input) dimensions. 
If there is a bias component,\n  we concatenate it to the reshaped weights/filter component.\n\n  Args:\n    vector: A Tensor or pair of Tensors shaped like layer parameters.\n\n  Returns:\n    A 2D Tensor with the same coefficients and the same output dimension.\n  \"\"\"\n  if isinstance(vector, (tuple, list)):\n    w_part, b_part = vector\n    w_part_reshaped = tf.reshape(w_part, [-1, w_part.shape.as_list()[-1]])\n    return tf.concat((w_part_reshaped, tf.reshape(b_part, [1, -1])), axis=0)\n  elif isinstance(vector, tf.IndexedSlices):\n    return vector\n  else:  # Tensor or Tensor-like.\n    return tf.reshape(vector, [-1, vector.shape.as_list()[-1]])\n\n\ndef mat2d_to_layer_params(vector_template, mat2d):\n  \"\"\"Converts a canonical 2D matrix representation back to a vector.\n\n  Args:\n    vector_template: A Tensor or pair of Tensors shaped like layer parameters.\n    mat2d: A 2D Tensor with the same shape as the value of\n        layer_params_to_mat2d(vector_template).\n\n  Returns:\n    A Tensor or pair of Tensors with the same coefficients as mat2d and the same\n        shape as vector_template.\n  \"\"\"\n  if isinstance(vector_template, (tuple, list)):\n    w_part, b_part = mat2d[:-1], mat2d[-1]\n    return tf.reshape(w_part, vector_template[0].shape), b_part\n  elif isinstance(vector_template, tf.IndexedSlices):\n    if not isinstance(mat2d, tf.IndexedSlices):\n      raise TypeError(\n          \"If vector_template is an IndexedSlices, so should mat2d.\")\n    return mat2d\n  else:\n    return tf.reshape(mat2d, vector_template.shape)\n\n\ndef posdef_inv(tensor, damping):\n  \"\"\"Computes the inverse of tensor + damping * identity.\"\"\"\n  identity = tf.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)\n  damping = tf.cast(damping, dtype=tensor.dtype)\n  return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping)\n\n\ndef posdef_inv_matrix_inverse(tensor, identity, damping):\n  \"\"\"Computes inverse(tensor + damping * identity) 
directly.\"\"\"\n  return tf.matrix_inverse(tensor + damping * identity)\n\n\ndef posdef_inv_cholesky(tensor, identity, damping):\n  \"\"\"Computes inverse(tensor + damping * identity) with Cholesky.\"\"\"\n  chol = tf.linalg.cholesky(tensor + damping * identity)\n  return tf.linalg.cholesky_solve(chol, identity)\n\n\ndef posdef_inv_eig(tensor, identity, damping):\n  \"\"\"Computes inverse(tensor + damping * identity) with eigendecomposition.\"\"\"\n  eigenvalues, eigenvectors = tf.self_adjoint_eig(tensor + damping * identity)\n  return tf.matmul(eigenvectors / eigenvalues, eigenvectors, transpose_b=True)\n\n\nposdef_inv_functions = {\n    \"matrix_inverse\": posdef_inv_matrix_inverse,\n    \"cholesky\": posdef_inv_cholesky,\n    \"eig\": posdef_inv_eig,\n}\n\n\ndef posdef_eig(mat):\n  \"\"\"Computes the eigendecomposition of a positive semidefinite matrix.\"\"\"\n  return posdef_eig_functions[POSDEF_EIG_METHOD](mat)\n\n\ndef posdef_eig_svd(mat):\n  \"\"\"Computes the singular values and left singular vectors of a matrix.\"\"\"\n  evals, evecs, _ = tf.svd(mat)\n\n  return evals, evecs\n\n\ndef posdef_eig_self_adjoint(mat):\n  \"\"\"Computes eigendecomposition using self_adjoint_eig.\"\"\"\n  evals, evecs = tf.self_adjoint_eig(mat)\n  evals = tf.abs(evals)  # Should be equivalent to svd approach.\n\n  return evals, evecs\n\n\nposdef_eig_functions = {\n    \"self_adjoint\": posdef_eig_self_adjoint,\n    \"svd\": posdef_eig_svd,\n}\n\n\ndef cholesky(tensor, damping):\n  \"\"\"Computes the inverse of tensor + damping * identity.\"\"\"\n  identity = tf.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)\n  damping = tf.cast(damping, dtype=tensor.dtype)\n  return tf.linalg.cholesky(tensor + damping * identity)\n\n\nclass SubGraph(object):\n  \"\"\"Defines a subgraph given by all the dependencies of a given set of outputs.\n  \"\"\"\n\n  def __init__(self, outputs):\n    # Set of all ancestor Tensors, Ops to 'outputs'.\n    self._members = set()\n\n    
self._iter_add(outputs)\n    self._graph = outputs[0].graph\n\n  def _iter_add(self, root):\n    \"\"\"Iteratively adds all of nodes' ancestors using depth first search.\"\"\"\n    stack = [root]\n    while stack:\n      nodes = stack.pop()\n      for node in nodes:\n        if node in self._members:\n          continue\n        self._members.add(node)\n\n        if isinstance(node, tf.Tensor):\n          stack.append((node.op,))\n        elif isinstance(node, tf.Operation):\n          stack.append(node.inputs)\n\n  def is_member(self, node):\n    \"\"\"Check if 'node' is in this subgraph.\"\"\"\n    return node in self._members\n\n  def variable_uses(self, var):\n    \"\"\"Computes number of times a variable is used.\n\n    Args:\n      var: Variable or ResourceVariable instance.\n\n    Returns:\n      Number of times a variable is used within this subgraph.\n\n    Raises:\n      ValueError: If 'var' is not a variable type.\n    \"\"\"\n    def _add_tensor_consumers_to_set(tensor, consumers_set):\n      \"\"\"Finds consumers of a tensor and add them to the current consumers set.\n      \"\"\"\n      for consumer in set(tensor.consumers()):\n        # These are the type of ops which relay a tensor to other ops without\n        # doing anything to the tensor value, so recursively find the actual\n        # consumers.\n        if consumer.type in [\n            \"Identity\", \"ReadVariableOp\", \"Enter\", \"ResourceGather\"]:\n          for output in consumer.outputs:\n            _add_tensor_consumers_to_set(output, consumers_set)\n        else:\n          consumers_set.add(consumer)\n\n    consumers = set()\n    if resource_variable_ops.is_resource_variable(var):\n      if tf.control_flow_v2_enabled() and hasattr(self._graph, \"captures\"):\n        # TODO(b/143690035): Note that the \"captures\" property relies on an API\n        # which might change.\n        captures = self._graph.captures\n        for handle in [h for vh, h in captures if vh is var.handle]:\n   
       _add_tensor_consumers_to_set(handle, consumers)\n      else:\n        _add_tensor_consumers_to_set(var.handle, consumers)\n    elif is_reference_variable(var):\n      _add_tensor_consumers_to_set(var.value(), consumers)\n    else:\n      raise ValueError(\"%s does not appear to be a variable.\" % str(var))\n\n    return len(self._members.intersection(consumers))\n\n  def filter_list(self, node_list):\n    \"\"\"Filters 'node_list' to nodes in this subgraph.\"\"\"\n    filtered_list = []\n    for node in node_list:\n      if self.is_member(node):\n        filtered_list.append(node)\n    return filtered_list\n\n\ndef preferred_int_dtype():\n  # tf.int32 doesn't work properly on GPUs, and tf.int64 isn't recommended on\n  # TPUs. Hence this function.\n  if is_tpu_replicated():\n    return tf.int32\n  else:\n    return tf.int64\n\n\ndef generate_random_signs(shape, dtype=tf.float32):\n  \"\"\"Generate a random tensor with {-1, +1} entries.\"\"\"\n  ints = tf.random_uniform(shape, maxval=2, dtype=preferred_int_dtype())\n  return 2 * tf.cast(ints, dtype=dtype) - 1\n\n\n# MirroredVariables do not have a hashable op property, which means they cannot\n# be used with stop_gradients. 
This was fixed in the TF-Nightly release, but is\n# not in any stable release, so we use the below hack so our fwd_gradients\n# function works in the TF 1.14 stable release.\n# TODO(b/139376871): Remove this workaround once the bugfix is in a stable release.\nDistributedVarOp = collections.namedtuple(\n    \"DistributedVarOp\", [\"name\", \"graph\", \"traceback\", \"type\"])\n\nclass MirroredVariableWrapper(object):\n\n  def __init__(self, var):\n    self.__var = var\n\n  def __getattr__(self, name):\n    if name == 'op':\n      return DistributedVarOp(\n          self.__var.op.name,\n          self.__var.op.graph,\n          # In the updated TF codebase, convert_stack returns tuple instead of\n          # list, which makes op.traceback hashable.\n          tuple(self.__var.op.traceback),\n          self.__var.op.type)\n    else:\n      return getattr(self.__var, name)\n\n\ndef _as_list(x):\n  return x if isinstance(x, (list, tuple)) else [x]\n\n\ndef fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None,\n                  colocate_gradients_with_ops=True):\n  \"\"\"Compute forward-mode gradients.\"\"\"\n  # See b/37888268.\n\n  # This version of forward-mode autodiff is based on code by Tim Cooijmans\n  # and handles list arguments and certain special cases such as when the\n  # ys doesn't depend on one or more of the xs, and when tf.IndexedSlices are\n  # generated by the first tf.gradients call.\n\n  ys = _as_list(ys)\n  xs = _as_list(xs)\n\n  us = [tf.zeros_like(y) + float(\"nan\") for y in ys]\n  if tf.distribute.has_strategy():\n    stop_gradients = [MirroredVariableWrapper(v) for v in stop_gradients]\n  dydxs = tf.gradients(ys, xs, grad_ys=us, stop_gradients=stop_gradients,\n                       colocate_gradients_with_ops=colocate_gradients_with_ops)\n\n  # Deal with strange types that tf.gradients returns but can't\n  # deal with.\n  dydxs = [\n      tf.convert_to_tensor(dydx) if isinstance(dydx, tf.IndexedSlices) else dydx\n      for dydx in dydxs\n  
]\n  dydxs = [\n      tf.zeros_like(x) if dydx is None else dydx for x, dydx in zip(xs, dydxs)\n  ]\n  dysdx = tf.gradients(dydxs, us, grad_ys=grad_xs,\n                       colocate_gradients_with_ops=colocate_gradients_with_ops)\n  return dysdx\n\n\ndef get_tf_replicator():\n  return _TF_REPLICATOR\n\n\ndef is_tpu_replicated():\n  is_tpu_strategy = (tf.distribute.has_strategy() and\n                     tf.distribute.get_replica_context() and\n                     isinstance(tf.distribute.get_strategy(),\n                                tf.distribute.experimental.TPUStrategy))\n  num_shards = tpu_function.get_tpu_context().number_of_shards\n  return is_tpu_strategy or num_shards is not None\n\n\ndef is_replicated():\n  \"\"\"Check if we are operating in a supported replicated context.\"\"\"\n  if tf.distribute.has_strategy() and tf.distribute.get_replica_context():\n    return tf.distribute.get_strategy().num_replicas_in_sync > 1\n  return get_tf_replicator() is not None or is_tpu_replicated()\n\n\ndef get_num_replicas():\n  \"\"\"Returns the number of replicas.\n\n  If not operating in a supported replicated context this function will return\n  1.\n  \"\"\"\n\n  tf_replicator = get_tf_replicator()\n\n  if tf_replicator:\n    return tf_replicator.num_replicas_in_sync\n  elif tf.distribute.has_strategy():\n    return tf.distribute.get_strategy().num_replicas_in_sync\n  else:\n    # I'm assuming replicas and shards are always equal until someone tells me\n    # different.\n    num_replicas = tpu_function.get_tpu_context().number_of_shards\n    if num_replicas:\n      return num_replicas\n    else:\n      return 1\n\n\ndef get_replica_id():\n  \"\"\"Returns an id number for the current replica, counting from 0.\n\n  If not operating in a supported replicated context this function will return\n  0.\n  \"\"\"\n\n  tf_replicator = get_tf_replicator()\n\n  if tf_replicator:\n    return tf_replicator.current_replica_id\n  elif tf.distribute.has_strategy() and 
tf.distribute.get_replica_context():\n    return tf.distribute.get_replica_context().replica_id_in_sync_group\n\n  # This code below this point is based on\n  # TensorTracer._add_replica_id_to_graph().\n  num_replicas = get_num_replicas()\n\n  if num_replicas <= 1:\n    return 0\n\n  with tf.control_dependencies(None):\n    # Uses None as dependency to run outside of TPU graph rewrites.\n    return tpu_ops.tpu_replicated_input(list(range(num_replicas)),\n                                        name=\"replica_id\")\n\n\ndef all_sum(structure, name=None):\n  \"\"\"Sums the contents of a nested structure across all replicas.\n\n  If not operating in a supported replicated context this function acts like\n  the identity.\n\n  Args:\n    structure: A nested structure of Tensors.\n    name: None or string. Optional name of Op. (Default: None)\n\n  Returns:\n    A nested structure with the corresponding Tensors being the cross-replica\n    summed versions of those in `structure`.\n  \"\"\"\n  num_replicas = get_num_replicas()\n\n  if num_replicas <= 1:\n    return structure\n\n  tf_replicator = get_tf_replicator()\n  if tf_replicator:\n    return tf_replicator.all_sum(structure)\n\n  elif tf.distribute.has_strategy() and tf.distribute.get_replica_context():\n    return tf.distribute.get_replica_context().all_reduce(\n        tf.distribute.ReduceOp.SUM, structure)\n\n  elif is_tpu_replicated():\n    def tpu_all_sum(tensor):\n      return tpu_ops.cross_replica_sum(tensor, name=name)\n\n    return nest.map_structure(tpu_all_sum, structure)\n\n  return structure\n\n\ndef all_average(structure, name=None):\n  \"\"\"Averages the contents of a nested structure across all replicas.\n\n  If not operating in a supported replicated context this function acts like\n  the identity.\n\n  Args:\n    structure: A nested structure of Tensors.\n    name: None or string. Optional name of Op. 
(Default: None)\n\n  Returns:\n    A nested structure with the corresponding Tensors being the cross-replica\n    averaged versions of those in `structure`.\n  \"\"\"\n  num_replicas = get_num_replicas()\n\n  if num_replicas <= 1:\n    return structure\n\n  if (tf.distribute.has_strategy() and tf.distribute.get_replica_context()\n      and not get_tf_replicator()):\n    return tf.distribute.get_replica_context().all_reduce(\n        tf.distribute.ReduceOp.MEAN, structure)\n\n  return nest.map_structure(lambda x: x / num_replicas, all_sum(structure,\n                                                                name=name))\n\n\ndef map_gather(thunks, name=None):\n  \"\"\"Distributes the execution of thunks over replicas, then gathers results.\n\n    This method can be used to distribute several expensive computations across\n    the replicas, rather than duplicating the computation in all of them.\n\n  Args:\n    thunks: A list of thunks that each returns a nested structure of Tensors.\n      These should all have statically known shapes.\n    name: None or string. Optional name of Op. 
(Default: None)\n\n  Returns:\n    A list of nested structures of Tensors representing the return values of\n    the list of thunks.\n  \"\"\"\n\n  num_replicas = get_num_replicas()\n\n  if num_replicas <= 1:\n    return tuple(thunk() for thunk in thunks)\n\n  tf_replicator = get_tf_replicator()\n\n  if tf_replicator:\n    return tf_replicator.map_gather(thunks, lambda thunk: thunk())\n\n  elif is_tpu_replicated():\n    replica_id = get_replica_id()\n\n    def zeros_like(tensor):\n      return tf.zeros(dtype=tensor.dtype, shape=tensor.shape)\n\n    results = []\n    for idx, thunk in enumerate(thunks):\n      # TensorFlow's optimization should eliminate the actual computations\n      # done to compute example_structure, using only the (static) shape\n      # information.\n      def make_zeros_thunk(example_structure):\n        def zeros_thunk():\n          return nest.map_structure(zeros_like, example_structure)\n        return zeros_thunk\n\n      # This trick of using cross_replica_sum with tensors of zeros is\n      # obviously wasteful in terms of communication. 
A better solution would\n      # involve only communicating the tensors from replicas where `include_me`\n      # was True.\n      include_me = tf.equal(replica_id, idx % num_replicas)\n      results.append(\n          all_sum(tf.cond(include_me,\n                          thunk,\n                          make_zeros_thunk(thunk()),\n                          strict=True),\n                  name=name))\n\n    return results\n\n  return tuple(thunk() for thunk in thunks)\n\n\ndef ensure_sequence(obj):\n  \"\"\"If `obj` isn't a tuple or list, return a tuple containing `obj`.\"\"\"\n  if isinstance(obj, (tuple, list)):\n    return obj\n  else:\n    return (obj,)\n\n\ndef batch_execute(global_step, thunks, batch_size, name=None):\n  \"\"\"Executes a subset of ops per global step.\n\n  Given a list of thunks, each of which produces a single stateful op,\n  ensures that exactly 'batch_size' ops are run per global step. Ops are\n  scheduled in a round-robin fashion. For example, with 3 ops\n\n    global_step | op0 | op1 | op2\n    ------------+-----+-----+-----\n        0       |  x  |  x  |\n    ------------+-----+-----+-----\n        1       |  x  |     |  x\n    ------------+-----+-----+-----\n        2       |     |  x  |  x\n    ------------+-----+-----+-----\n        3       |  x  |  x  |\n    ------------+-----+-----+-----\n        4       |  x  |     |  x\n\n  Does not guarantee order of op execution within a single global step.\n\n  Args:\n    global_step: Tensor indicating time. Determines which ops run.\n    thunks: List of thunks. Each thunk encapsulates one op. Return values are\n      ignored.\n    batch_size: int. Number of ops to execute per global_step.\n    name: string or None. Name scope for newly added ops.\n\n  Returns:\n    List of ops. 
Exactly 'batch_size' ops are guaranteed to have an effect\n    every global step.\n  \"\"\"\n\n  def true_fn(thunk):\n    \"\"\"Ensures thunk is executed and returns an Op (not a Tensor).\"\"\"\n\n    def result():\n      with tf.control_dependencies([thunk()]):\n        return tf.no_op()\n\n    return result\n\n  def false_fn(_):\n    \"\"\"Executes a no-op.\"\"\"\n\n    def result():\n      return tf.no_op()\n\n    return result\n\n  with tf.name_scope(name, \"batch_execute\"):\n    true_fns = [true_fn(thunk) for thunk in thunks]\n    false_fns = [false_fn(thunk) for thunk in thunks]\n    num_thunks = len(thunks)\n    conditions = [\n        tf.less(\n            tf.mod(batch_size - 1 + global_step * batch_size - j, num_thunks),\n            batch_size) for j in range(num_thunks)\n    ]\n    result = [\n        tf.cond(condition, true_fn, false_fn)\n        for (condition, true_fn,\n             false_fn) in zip(conditions, true_fns, false_fns)\n    ]\n    return result\n\n\ndef extract_convolution_patches(inputs,\n                                filter_shape,\n                                padding,\n                                strides=None,\n                                dilation_rate=None,\n                                name=None,\n                                data_format=None):\n  \"\"\"Extracts inputs to each output coordinate in tf.nn.convolution.\n\n  This is a generalization of tf.extract_image_patches() to tf.nn.convolution(),\n  where the number of spatial dimensions may be something other than 2.\n\n  Assumes,\n  - First dimension of inputs is batch_size\n  - Convolution filter is applied to all input channels.\n\n  Args:\n    inputs: Tensor of shape [batch_size, ..spatial_image_shape..,\n      ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution().\n    filter_shape: List of ints. Shape of filter passed to tf.nn.convolution().\n    padding: string. Padding method. 
One of \"VALID\", \"SAME\".\n    strides: None or list of ints. Strides along spatial dimensions.\n    dilation_rate: None or list of ints. Dilation along spatial dimensions.\n    name: None or str. Name of Op.\n    data_format: None or str. Format of data.\n\n  Returns:\n    Tensor of shape [batch_size, ..spatial_image_shape..,\n      ..spatial_filter_shape.., in_channels]\n\n  Raises:\n    ValueError: If data_format does not put channel last.\n    ValueError: If inputs and filter disagree on in_channels.\n  \"\"\"\n  if not is_data_format_channel_last(data_format):\n    raise ValueError(\"Channel must be last dimension.\")\n  with tf.name_scope(name, \"extract_convolution_patches\",\n                     [inputs, filter_shape, padding, strides, dilation_rate]):\n    batch_size = inputs.shape.as_list()[0]\n    in_channels = inputs.shape.as_list()[-1]\n\n    # filter_shape = spatial_filter_shape + [in_channels, out_channels]\n    spatial_filter_shape = filter_shape[:-2]\n    if in_channels != filter_shape[-2]:\n      raise ValueError(\"inputs and filter_shape must agree on in_channels.\")\n\n    # Map each input feature to a location in the output.\n    out_channels = np.prod(spatial_filter_shape) * in_channels\n    filters = tf.eye(out_channels, dtype=inputs.dtype)\n    filters = tf.reshape(\n        filters,\n        list(spatial_filter_shape) + [in_channels, out_channels])\n\n    if strides is not None and len(strides) == len(inputs.shape):\n      strides = strides[1:-1]  # remove batch and channel dimension\n\n    if dilation_rate is not None and len(dilation_rate) == len(inputs.shape):\n      dilation_rate = dilation_rate[1:-1]  # remove batch and channel dimension\n\n    result = tf.nn.convolution(\n        inputs,\n        filters,\n        padding=padding,\n        strides=strides,\n        dilation_rate=dilation_rate)\n    spatial_output_shape = result.shape.as_list()[1:-1]\n    result = tf.reshape(result, [batch_size or -1] + spatial_output_shape +\n      
                  list(spatial_filter_shape) + [in_channels])\n\n    return result\n\n\ndef extract_pointwise_conv2d_patches(inputs,\n                                     filter_shape,\n                                     name=None,\n                                     data_format=None):\n  \"\"\"Extract patches for a 1x1 conv2d.\n\n  Args:\n    inputs: 4-D Tensor of shape [batch_size, height, width, in_channels].\n    filter_shape: List of 4 ints. Shape of filter to apply with conv2d()\n    name: None or str. Name for Op.\n    data_format: None or str. Format for data. See 'data_format' in\n      tf.nn.conv2d() for details.\n\n  Returns:\n    Tensor of shape [batch_size, ..spatial_input_shape..,\n    ..spatial_filter_shape.., in_channels]\n\n  Raises:\n    ValueError: if inputs is not 4-D.\n    ValueError: if filter_shape is not [1, 1, ?, ?]\n    ValueError: if data_format is not channels-last.\n  \"\"\"\n  if inputs.shape.ndims != 4:\n    raise ValueError(\"inputs must have 4 dims.\")\n  if len(filter_shape) != 4:\n    raise ValueError(\"filter_shape must have 4 dims.\")\n  if filter_shape[0] != 1 or filter_shape[1] != 1:\n    raise ValueError(\"filter_shape must have shape 1 along spatial dimensions.\")\n  if not is_data_format_channel_last(data_format):\n    raise ValueError(\"data_format must be channels last.\")\n  with tf.name_scope(name, \"extract_pointwise_conv2d_patches\",\n                     [inputs, filter_shape]):\n    ksizes = [1, 1, 1, 1]  # Spatial shape is 1x1.\n    strides = [1, 1, 1, 1]  # Operate on all pixels.\n    rates = [1, 1, 1, 1]  # Dilation has no meaning with spatial shape = 1.\n    padding = \"VALID\"  # Doesn't matter.\n    result = tf.extract_image_patches(inputs, ksizes, strides, rates, padding)\n\n    batch_size, input_height, input_width, in_channels = inputs.shape.as_list()\n    filter_height, filter_width, in_channels, _ = filter_shape\n    return tf.reshape(result, [\n        batch_size, input_height, input_width, 
filter_height, filter_width,\n        in_channels\n    ])\n\n\ndef is_data_format_channel_last(data_format):\n  \"\"\"True if data_format puts channel last.\"\"\"\n  if data_format is None:\n    return True\n  return data_format.endswith(\"C\")\n\n\ndef matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False):  # pylint: disable=invalid-name\n  \"\"\"Computes matmul(A, B) where A is sparse, B is dense.\n\n  Args:\n    A: tf.IndexedSlices with dense shape [m, n].\n    B: tf.Tensor with shape [n, k].\n    name: str. Name of op.\n    transpose_a: Bool. If true we transpose A before multiplying it by B.\n      (Default: False)\n    transpose_b: Bool. If true we transpose B before multiplying it by A.\n      (Default: False)\n\n  Returns:\n    tf.IndexedSlices resulting from matmul(A, B).\n\n  Raises:\n    ValueError: If A doesn't represent a matrix.\n    ValueError: If B is not rank-2.\n  \"\"\"\n  with tf.name_scope(name, \"matmul_sparse_dense\", [A, B]):\n    if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2:\n      raise ValueError(\"A must represent a matrix. Found: %s.\" % A)\n    if B.shape.ndims != 2:\n      raise ValueError(\"B must be a matrix.\")\n    new_values = tf.matmul(\n        A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)\n    return tf.IndexedSlices(\n        new_values,\n        A.indices,\n        dense_shape=tf.stack([A.dense_shape[0], new_values.shape[1]]))\n\n\ndef matmul_diag_sparse(A_diag, B, name=None):  # pylint: disable=invalid-name\n  \"\"\"Computes matmul(A, B) where A is a diagonal matrix, B is sparse.\n\n  Args:\n    A_diag: diagonal entries of matrix A of shape [m, m].\n    B: tf.IndexedSlices. Represents matrix of shape [m, n].\n    name: str. 
Name of op.\n\n  Returns:\n    tf.IndexedSlices resulting from matmul(A, B).\n\n  Raises:\n    ValueError: If A_diag is not rank-1.\n    ValueError: If B doesn't represent a matrix.\n  \"\"\"\n  with tf.name_scope(name, \"matmul_diag_sparse\", [A_diag, B]):\n    A_diag = tf.convert_to_tensor(A_diag)\n    if A_diag.shape.ndims != 1:\n      raise ValueError(\"A_diag must be a rank-1 Tensor.\")\n    if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2:\n      raise ValueError(\"B must represent a matrix. Found: %s.\" % B)\n    a = tf.gather(A_diag, B.indices)\n    a = tf.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))\n    return tf.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)\n\n\nclass AccumulatorVariable(object):\n  \"\"\"A simple abstraction to accumulate data that we want to average.\n\n  Basically this variable accumulates data across multiple inputs, and\n  then returns the average of these contributions on command.  This accumulation\n  can be reset by the user at any point.\n  \"\"\"\n\n  def __init__(self, name, shape, dtype):\n    \"\"\"Constructs a new `AccumulatorVariable`.\n\n    Args:\n      name: `string`. Scope for the variables.\n      shape: shape of the variable.\n      dtype: dtype of the variable.\n    \"\"\"\n    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):\n      self._acc_var = tf.get_variable(\n          \"acc_var\",\n          shape=shape,\n          dtype=dtype,\n          initializer=tf.zeros_initializer(),\n          trainable=False,\n          use_resource=True)\n\n      # We may be able to give this a VariableAggregation of\n      # ONLY_FIRST_REPLICA, because we only add 1 or reset it to 0 (it does not\n      # rely on per-replica values). If we do, we can update this in a per-\n      # replica context instead of the cross-replica context. 
This may improve\n      # efficiency when using a VariableSynchronization of ON_READ.\n      self._counter = tf.get_variable(\n          \"counter\",\n          shape=(),\n          dtype=tf.float32,\n          initializer=tf.zeros_initializer(),\n          trainable=False,\n          use_resource=True)\n\n  def accumulate(self, value):\n    \"\"\"Adds `value` to the accumulated data.\"\"\"\n    inc_counter_op = smart_assign(self._counter, 1.0, assign_fn=tf.assign_add)\n    acc_op = smart_assign(self._acc_var, value, assign_fn=tf.assign_add)\n    return tf.group(inc_counter_op, acc_op)\n\n  @property\n  def value(self):\n    \"\"\"Returns the average of the accumulated values since the last reset.\"\"\"\n    return self._acc_var / tf.cast(self._counter, self._acc_var.dtype)\n\n  def read_value_and_reset(self):\n    \"\"\"Same as `value` property but resets after the data is read.\"\"\"\n    value = self.value\n    with tf.control_dependencies([value]):\n      with tf.control_dependencies([self.reset()]):\n        return tf.identity(value)\n\n  def reset(self):\n    \"\"\"Resets the accumulated data to zero.\"\"\"\n    var_reset_op = smart_assign(\n        self._acc_var, tf.zeros(self._acc_var.shape, dtype=self._acc_var.dtype))\n    counter_reset_op = smart_assign(self._counter,\n                                    tf.constant(0.0, dtype=tf.float32))\n\n    return tf.group(var_reset_op, counter_reset_op)\n\n\nclass PartitionedTensor(object):\n  \"\"\"A Tensor partitioned across its 0-th dimension.\"\"\"\n\n  def __init__(self, tensors):\n    \"\"\"Initializes PartitionedTensor.\n\n    Args:\n      tensors: List of Tensors. 
All Tensors must agree on shape (excepting\n        batch dimension) and dtype.\n\n    Raises:\n      ValueError: If 'tensors' has length zero.\n      ValueError: if contents of 'tensors' don't agree on shape or dtype.\n    \"\"\"\n    if not tensors:\n      raise ValueError(\"tensors must be a list of 1+ Tensors.\")\n\n    dtype = tensors[0].dtype\n    if not all(tensor.dtype == dtype for tensor in tensors):\n      raise ValueError(\n          \"all tensors must have the same dtype. The tensors are {}\".format(\n              tensors))\n\n    shape = tensors[0].shape[1:]\n    if not all(tensor.shape[1:] == shape for tensor in tensors):\n      raise ValueError(\"All tensors must have shape = %s (excluding batch \"\n                       \"dimension).\" % shape)\n\n    one_hot_depth = getattr(tensors[0], \"one_hot_depth\", None)\n    if not all(\n        getattr(tensor, \"one_hot_depth\", None) == one_hot_depth\n        for tensor in tensors):\n      raise ValueError(\n          \"All tensors must have one_hot_depth {}\".format(one_hot_depth))\n\n    self.tensors = tensors\n\n  @property\n  def shape(self):\n    feature_shape = self.tensors[0].shape[1:]\n    batch_size = sum([tensor.shape[0] for tensor in self.tensors],\n                     tf.Dimension(0))\n    return tf.TensorShape([batch_size]).concatenate(feature_shape)\n\n  def get_shape(self):\n    return self.shape\n\n  @property\n  def dtype(self):\n    return self.tensors[0].dtype\n\n  @property\n  def one_hot_depth(self):\n    return getattr(self.tensors[0], \"one_hot_depth\", None)\n\n  def __str__(self):\n    return \"PartitionedTensor([%s, ...], dtype=%s, shape=%s)\" % (\n        self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))\n\n  def __hash__(self):\n    return hash(tuple(self.tensors))\n\n  def __eq__(self, other):\n    if not isinstance(other, PartitionedTensor):\n      return False\n    return self.tensors == other.tensors\n\n  def __ne__(self, other):\n    return not self == 
other  # pylint: disable=g-comparison-negation\n\n  def __getitem__(self, key):\n    return self.as_tensor()[key]\n\n  def as_tensor(self, dtype=None, name=None, as_ref=False):\n    with tf.name_scope(name, \"PartitionedTensor.as_tensor\", self.tensors):\n      assert not as_ref\n      assert dtype in [None, self.dtype]\n      return tf.concat(self.tensors, axis=0)\n\n  @property\n  def device(self):\n    # PartitionedTensors in general do not live on a single device.  If the\n    # device cannot be determined unambiguously this property will return None.\n    device = self.tensors[0].device\n    if all(tensor.device == device for tensor in self.tensors):\n      return device\n    return None\n\n\ntf.register_tensor_conversion_function(\n    PartitionedTensor,\n    lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))\n\n\n# TODO(b/69623235): Add a function for finding tensors that share gradients\n# to eliminate redundant fisher factor computations.\n\n\ndef _check_match_lists_of_pairs(list1, list2):\n  for (_, var1), (_, var2) in zip(list1, list2):\n    if var1 is not var2:\n      raise ValueError(\"The variables referenced by the two arguments \"\n                       \"must match.\")\n\n\ndef sprod(scalar, list_):\n  # Product of scalar with list of items.\n  return tuple(scalar*item for item in list_)\n\n\ndef sprod_p(scalar, list_):\n  # Product of scalar with list of (item, var) pairs.\n  return tuple((scalar*item, var) for (item, var) in list_)\n\n\ndef sum_(list1, list2):\n  # Element-wise sum of lists of tensors.\n  return tuple(item1 + item2 for item1, item2 in zip(list1, list2))\n\n\ndef sum_p(list1, list2):\n  # Element-wise sum of lists of (tensor, var) pairs.\n  _check_match_lists_of_pairs(list1, list2)\n  return tuple((item1 + item2, var1)\n               for (item1, var1), (item2, var2) in zip(list1, list2))\n\n\ndef ip(list1, list2):\n  # Inner product of lists of tensors.\n  return tf.add_n(tuple(tf.reduce_sum(tensor1 * 
tensor2)\n                        for tensor1, tensor2 in zip(list1, list2)))\n\n\ndef ip_p(list1, list2):\n  # Inner product of lists of (tensor, var) pairs.\n  _check_match_lists_of_pairs(list1, list2)\n\n  return ip(tuple(tensor for (tensor, _) in list1),\n            tuple(tensor for (tensor, _) in list2))\n\n\ndef assert_variables_match_pairs_list(a_and_vars,\n                                      b_and_vars,\n                                      error_message=None):\n  \"\"\"Assert the variables in two lists of (tensor, var) pairs are the same.\n\n  Args:\n    a_and_vars: a list of (tensor, variable) pairs.\n    b_and_vars: a list of (tensor, variable) pairs.\n    error_message: an optional string prepended to the error message.\n\n  Raises:\n    ValueError: if any variables in the input pair lists are not the same.\n  \"\"\"\n  _, a_variables = zip(*a_and_vars)\n  _, b_variables = zip(*b_and_vars)\n  variable_mismatch_indices = []\n  for vi, (a_var, b_var) in enumerate(zip(a_variables, b_variables)):\n    if a_var is not b_var:\n      variable_mismatch_indices.append(vi)\n\n  if variable_mismatch_indices:\n    mismatch_indices_str = \", \".join(map(str, variable_mismatch_indices))\n    a_variables_str = \", \".join(map(str, a_variables))\n    b_variables_str = \", \".join(map(str, b_variables))\n    error_str = (\"Mismatch on variable lists at indices {}.\\n\\nFirst list:  {}\"\n                 \"\\n\\nSecond list:  {} \\n\").format(\n        mismatch_indices_str, a_variables_str, b_variables_str)\n    if error_message:\n      error_str = \"{} {}\".format(error_message, error_str)\n    raise ValueError(error_str)\n\n\ndef multiline_print(lists):\n  \"\"\"Prints multiple lines of output using tf.print.\"\"\"\n\n  combined_list = []\n  combined_list += lists[0]\n\n  # We prepend newline characters to strings at the start of lines to avoid\n  # the ugly space indentations that tf.print's behavior of separating\n  # everything with a space would otherwise 
cause.\n  for item in lists[1:]:\n    if isinstance(item[0], str):\n      combined_list += ((\"\\n\" + item[0],) + item[1:])\n    else:\n      combined_list += ((\"\\n\",) + item)\n\n  return tf.print(*combined_list)\n\n\ndef get_shape(tensor):\n  \"\"\"Returns list of dimensions using ints only for statically known ones.\"\"\"\n\n  if tensor.shape.dims is None:\n    raise ValueError(\"Unknown rank for tensor {}.\".format(tensor))\n\n  static_shape = tensor.shape.as_list()\n  dynamic_shape = tf.shape(tensor)\n  return tuple(elt if elt is not None else dynamic_shape[idx]\n               for idx, elt in enumerate(static_shape))\n\n\ndef cls_name(obj):\n  return obj.__class__.__name__\n\n\ndef is_reference_variable(x):\n  return ((isinstance(x, tf.Variable)\n           and not resource_variable_ops.is_resource_variable(x))\n          or hasattr(x, \"_should_act_as_ref_variable\"))\n\n\nclass MovingAverageVariable(object):\n  \"\"\"A variable updated using weighted moving averages.\n\n  Note that to implement a traditional decaying exponential average one should\n  use a decay value smaller than 1.0 (e.g. 0.9), and set weight = 1.0 - decay.\n  Doing this and setting normalize_value to True will implement \"zero-debiased\"\n  decayed averages.\n  \"\"\"\n\n  def __init__(self, name, shape, dtype, initializer=tf.zeros_initializer(),\n               normalize_value=True):\n    \"\"\"Constructs a new `MovingAverageVariable`.\n\n    Args:\n      name: `string`. Scope for the variables.\n      shape: shape of the variable.\n      dtype: dtype of the variable.\n      initializer: initializer for the variable (see tf.get_variable). Should\n        be tf.zeros_initializer() unless you know what you are doing.\n        (Default: tf.zeros_initializer())\n      normalize_value: bool. If True we normalize the value property by the\n        total weight (which will be subject to decay). 
(Default: True)\n    \"\"\"\n    self._normalize_value = normalize_value\n\n    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):\n      self._var = tf.get_variable(\n          \"var\",\n          shape=shape,\n          dtype=dtype,\n          initializer=initializer,\n          trainable=False,\n          use_resource=True)\n\n      self._total_weight = tf.get_variable(\n          \"total_weight\",\n          shape=(),\n          dtype=dtype,\n          initializer=tf.zeros_initializer(),\n          trainable=False,\n          use_resource=True)\n\n  @property\n  def dtype(self):\n    return self._var.dtype.base_dtype\n\n  @property\n  def value(self):\n    if self._normalize_value:\n      return self._var / self._total_weight\n    else:\n      return tf.identity(self._var)\n\n  def add_to_average(self, value, decay=1.0, weight=1.0):\n    \"\"\"Add a value into the moving average.\n\n    Args:\n      value: a Tensor matching the shape and dtype that was passed to the\n        constructor.\n      decay: float or 0D Tensor. The current value is multiplied by this before\n        the value is added, as is the total accumulated weight. (Default: 1.0)\n      weight: float or 0D Tensor. The value being added is multiplied by this.\n        Also this is added to the total accumulated weight. 
(Default: 1.0)\n    \"\"\"\n    decay = tf.cast(decay, dtype=self.dtype)\n    weight = tf.cast(weight, dtype=self.dtype)\n\n    update_var = smart_assign(self._var, decay * self._var + weight * value)\n\n    update_total_weight = smart_assign(self._total_weight,\n                                       decay * self._total_weight + weight)\n\n    return tf.group(update_var, update_total_weight)\n\n  def reset(self):\n    return tf.group(\n        smart_assign(self._var, tf.zeros_like(self._var)),\n        smart_assign(self._total_weight, tf.zeros_like(self._total_weight))\n        )\n\n\ndef num_conv_locations(input_shape, filter_shape, strides, padding):\n  \"\"\"Returns the number of spatial locations a conv kernel is applied to.\n\n  Args:\n    input_shape: List of ints representing shape of inputs to\n      tf.nn.convolution().\n    filter_shape: List of ints representing shape of filter to\n      tf.nn.convolution().\n    strides: List of ints representing strides along spatial dimensions as\n      passed in to tf.nn.convolution().\n    padding: string representing the padding method, either 'VALID' or 'SAME'.\n\n  Returns:\n    A scalar |T| denoting the number of spatial locations for the Conv layer.\n\n  Raises:\n    ValueError: If input_shape, filter_shape don't represent a 1-D or 2-D\n      convolution.\n  \"\"\"\n  if len(input_shape) != 4 and len(input_shape) != 3:\n    raise ValueError(\"input_shape must be length 4, corresponding to a Conv2D,\"\n                     \" or length 3, corresponding to a Conv1D.\")\n  if len(input_shape) != len(filter_shape):\n    raise ValueError(\"Inconsistent number of dimensions between input and \"\n                     \"filter for convolution\")\n\n  if strides is None:\n    if len(input_shape) == 4:\n      strides = [1, 1, 1, 1]\n    else:\n      strides = [1, 1, 1]\n\n  # Use negative integer division to implement 'rounding up'.\n  # Formula for convolution shape taken from:\n  # 
http://machinelearninguru.com/computer_vision/basics/convolution/convolution_layer.html\n  if len(input_shape) == 3:\n    if padding is not None and padding.lower() == \"valid\":\n      out_width = -(-(input_shape[1] - filter_shape[0] + 1) // strides[1])\n    else:\n      out_width = -(-input_shape[1] // strides[1])\n\n    return out_width\n  else:\n    if padding is not None and padding.lower() == \"valid\":\n      out_height = -(-(input_shape[1] - filter_shape[0] + 1) // strides[1])\n      out_width = -(-(input_shape[2] - filter_shape[1] + 1) // strides[2])\n    else:\n      out_height = -(-input_shape[1] // strides[1])\n      out_width = -(-input_shape[2] // strides[2])\n\n    return out_height * out_width\n"
  },
  {
    "path": "setup.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Install kfac.\"\"\"\n\nfrom setuptools import find_packages\nfrom setuptools import setup\n\nsetup(\n    name='kfac',\n    version='0.2.4',\n    description='K-FAC for TensorFlow',\n    author='Google Inc.',\n    author_email='no-reply@google.com',\n    url='http://github.com/tensorflow/kfac',\n    license='Apache 2.0',\n    packages=find_packages(exclude=[\n        'kfac.examples.*',\n        'kfac.python.kernel_tests.*',\n    ]),\n    install_requires=[\n        'numpy',\n        'six',\n        'tensorflow-probability==0.8',\n        'h5py<3',\n    ],\n    extras_require={\n        # It's possible that you might need to put tensorflow<2.0 here:\n        'tensorflow': ['tensorflow>=1.14'],\n        # It's possible that you might need to put tensorflow-gpu<2.0 here:\n        'tensorflow_gpu': ['tensorflow-gpu>=1.14'],\n        # dm-sonnet<2.0 will force tensorflow<2.0 in the tests:\n        'tests': ['pytest', 'dm-sonnet<2.0', 'numpy<1.20'],\n    },\n    classifiers=[\n        'Development Status :: 4 - Beta',\n        'Intended Audience :: Developers',\n        'Intended Audience :: Science/Research',\n        'License :: OSI Approved :: Apache Software License',\n        'Topic :: Scientific/Engineering :: Artificial Intelligence',\n    ],\n    
keywords='tensorflow machine learning',\n)\n"
  }
]