Repository: tensorflow/kfac Branch: master Commit: ddad6375bbde Files: 69 Total size: 989.3 KB Directory structure: gitextract_xp2vjhqo/ ├── .travis.yml ├── AUTHORS ├── LICENSE ├── README.md ├── docs/ │ ├── applications.md │ ├── contact.md │ ├── examples/ │ │ ├── auto_damp.md │ │ ├── convolutional.md │ │ ├── distributed_training.md │ │ └── parameters.md │ ├── index.md │ ├── papers.md │ └── sitemap.md ├── kfac/ │ ├── __init__.py │ ├── examples/ │ │ ├── __init__.py │ │ ├── autoencoder_mnist.py │ │ ├── autoencoder_mnist_tpu_estimator.py │ │ ├── autoencoder_mnist_tpu_strategy.py │ │ ├── classifier_mnist.py │ │ ├── classifier_mnist_tpu_estimator.py │ │ ├── convnet.py │ │ ├── keras/ │ │ │ ├── KFAC_vs_Adam_Experiment.md │ │ │ ├── KFAC_vs_Adam_on_CIFAR10.ipynb │ │ │ └── KFAC_vs_Adam_on_CIFAR10_TPU.ipynb │ │ ├── mnist.py │ │ └── rnn_mnist.py │ └── python/ │ ├── __init__.py │ ├── keras/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── callbacks.py │ │ ├── optimizers.py │ │ ├── saving_utils.py │ │ └── utils.py │ ├── kernel_tests/ │ │ ├── data_reader_test.py │ │ ├── estimator_test.py │ │ ├── graph_search_test.py │ │ ├── keras_callbacks_test.py │ │ ├── keras_optimizers_test.py │ │ ├── keras_saving_utils_test.py │ │ ├── keras_utils_test.py │ │ ├── layer_collection_test.py │ │ ├── loss_functions_test.py │ │ ├── op_queue_test.py │ │ ├── optimizer_test.py │ │ ├── periodic_inv_cov_update_kfac_opt_test.py │ │ └── utils_test.py │ └── ops/ │ ├── __init__.py │ ├── curvature_matrix_vector_products.py │ ├── estimator.py │ ├── fisher_blocks.py │ ├── fisher_factors.py │ ├── kfac_utils/ │ │ ├── __init__.py │ │ ├── async_inv_cov_update_kfac_opt.py │ │ ├── data_reader.py │ │ ├── data_reader_alt.py │ │ └── periodic_inv_cov_update_kfac_opt.py │ ├── layer_collection.py │ ├── linear_operator.py │ ├── loss_functions.py │ ├── op_queue.py │ ├── optimizer.py │ ├── placement.py │ ├── tensormatch/ │ │ ├── __init__.py │ │ ├── graph_matcher.py │ │ ├── graph_patterns.py │ │ ├── graph_search.py │ │ └── tensorflow_graph_util.py │ └── utils.py └── setup.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .travis.yml ================================================ language: python python: - "3.6" env: matrix: - TF_VERSION="1.15" install: - pip install -q "tensorflow==$TF_VERSION" - pip install -q .[tests] # Make sure we have the latest version of numpy - avoid problems we were # seeing with Python 3 - pip install -q -U numpy script: # Check import - python -c "import kfac; print(kfac.LayerCollection.__name__)" # Run tests - pytest git: depth: 3 ================================================ FILE: AUTHORS ================================================ # This is the official list of TensorFlow authors for copyright purposes. # This file is distinct from the CONTRIBUTORS files. # See the latter for an explanation. # Names should be added to this file as: # Name or Organization # The email address is not required for organizations. Google Inc. ================================================ FILE: LICENSE ================================================ Copyright 2019 The TensorFlow Authors. All rights reserved. Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. 
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2019, The TensorFlow Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
================================================ FILE: README.md ================================================ # K-FAC: Kronecker-Factored Approximate Curvature [![Travis](https://img.shields.io/travis/tensorflow/kfac.svg)](https://travis-ci.org/tensorflow/kfac) **K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an approximate second-order optimization method, in TensorFlow. [kfac-paper]: https://arxiv.org/abs/1503.05671 ## Installation `kfac` is compatible with Python 2 and 3 and can be installed directly via `pip`, ```shell # Assumes tensorflow or tensorflow-gpu installed $ pip install kfac # Installs with tensorflow-gpu requirement $ pip install 'kfac[tensorflow_gpu]' # Installs with tensorflow (cpu) requirement $ pip install 'kfac[tensorflow]' ``` ## KFAC DOCS Please check [KFAC docs][kfac_docs] for a detailed description with examples of how to use KFAC. Check the [Keras KFAC docs][keras_docs] for information on using KFAC with Keras. [kfac_docs]: https://github.com/tensorflow/kfac/tree/master/docs/index.md [keras_docs]: https://github.com/tensorflow/kfac/tree/master/kfac/python/keras/README.md ================================================ FILE: docs/applications.md ================================================ Coming Soon.. ================================================ FILE: docs/contact.md ================================================ Topic | Contact -------------------- | --------------------- **Questions** | kfac-users@google.com **Development Team** | kfac-dev@google.com Primary contacts: * James Martens (jamesmartens@google.com) Contributors (past and present): * Alok Aggarwal (aloka@google.com) * Daniel Duckworth (duckworthd@google.com) * David Pfau (pfau@google.com) * Dominik Grewe (dominikg@google.com) * Guodong Zhang (gdzhang@google.com) * James Keeling (jtkeeling@google.com) * James Martens (jamesmartens@google.com) * Jimmy Ba (jba@cs.toronto.edu) * Lala Li (lala@google.com) * Matthew Johnson (mattjj@google.com) * Nicholas Vadivelu (nvadivelu@google.com) * Noah Siegel (siegeln@google.com) * Olga Wichrowskaa (olganw@google.com) * Rishabh Kabra (rkabra@google.com) * Roger Grosse (rgrosse@google.com) * Soham De (sohamde@google.com) * Tamas Berghammer (tberghammer@google.com) * Vikram Tankasali (tvikram@google.com) * Zachary Nado (znado@google.com) ================================================ FILE: docs/examples/auto_damp.md ================================================ # Automatic tuning of damping parameter. ## Table of Contents * [1. Cached Reader](#1-cached-reader) * [2. Build optimizer and set damping parameters](#2-build-optimizer-and-set-damping-parameters) * [TIPS:](#tips)
The [KFAC damping parameter][kfac_damp] can be auto-tuned using the Levenberg-Marquardt (LM) algorithm. For a detailed description of the algorithm refer to `Section 6` of the [KFAC Paper][kfac_paper]. Note that this is still a heuristic and may not always produce optimal results. It can be better or worse than a carefully tuned fixed value, depending on the problem.

[kfac_paper]: https://arxiv.org/pdf/1503.05671.pdf
[kfac_damp]: https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md

**Example code**: https://github.com/tensorflow/kfac/tree/master/kfac/examples/autoencoder_mnist.py

Using this method to auto-tune damping requires changes to the basic KFAC training script, which are described below. We only highlight the additional steps required vs. training with a fixed damping value (as in the [Convnet example][convexamplesec]).

[convexamplesec]: https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md

## 1. Cached Reader

Wrap the dataset into `CachedReader`. This allows us to access the previous batch of data.

```python
cached_reader = data_reader.CachedDataReader(dataset, max_batch_size)
minibatch = cached_reader(batch_size)
```

## 2. Build optimizer and set damping parameters

```python
optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
    learning_rate=1.0,
    damping=150.,
    momentum=0.95,
    layer_collection=layer_collection,
    batch_size=batch_size,
    adapt_damping=True,
    prev_train_batch=cached_reader.cached_batch,
    is_chief=True,
    loss_fn=loss_fn,
    damping_adaptation_decay=0.95,
    damping_adaptation_interval=FLAGS.damping_adaptation_interval,
)
train_op = optimizer.minimize(loss, global_step=global_step)
```

## TIPS:

1. Damping can also be tuned using Population Based Training ([PBT][PBT_link]). In our observations PBT works on par with auto-tuning via the LM algorithm, although it is obviously more computationally expensive. However, if you are already using PBT for other hyperparameters, consider tuning damping with PBT as well.

[PBT_link]: https://arxiv.org/abs/1711.09846

================================================
FILE: docs/examples/convolutional.md
================================================

# Convolutional

## Table of Contents

* [Build the Model](#build-the-model)
* [Register the layers and loss](#register-the-layers-and-loss)
* [Build the optimizer](#build-the-optimizer)
* [Fit the model](#fit-the-model)
* [TIPS](#tips)
K-FAC needs to know about the structure of your model in order to effectively optimize it. In particular, it needs to know about:

1. Each convolutional and feed-forward layer's inputs and outputs.
1. All of the model parameters.
1. The type of the loss function and its inputs.

Let's explore how we can use K-FAC to solve digit classification with MNIST using a simple convolutional model. In the following example we will illustrate how to use `PeriodicInvCovUpdateKfacOpt`, which is a subclass of `KfacOptimizer`. `PeriodicInvCovUpdateKfacOpt` handles the placement and execution of the covariance and inverse ops. We will also illustrate how to register the layers both manually and automatically using the graph scanner.

**Code**: https://github.com/tensorflow/kfac/tree/master/kfac/examples/convnet_mnist_single_main.py

## Build the Model

First, we begin by defining a model. In this case, we'll load MNIST and construct a 5-layer ConvNet. The model has 2 Conv/MaxPool pairs and a final linear layer. If we are registering the layers manually, we need to keep the inputs, outputs, and parameters (weights & biases) around, as illustrated here.

```python
# Load a dataset.
examples, labels = mnist.load_mnist(
    data_dir,
    num_epochs=num_epochs,
    batch_size=128,
    use_fake_data=use_fake_data,
    flatten_images=False)

# Build a ConvNet.
pre0, act0, params0 = conv_layer(
    layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
pre2, act2, params2 = conv_layer(
    layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
logits, params4 = linear_layer(
    layer_id=4, inputs=flat_act3, output_size=num_labels)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits))
accuracy = tf.reduce_mean(
    tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
```

## Register the layers and loss

`layer_collection.auto_register_layers` automatically registers all the layers for typical/standard models. However, one must still manually register the loss function. In the case of cross-entropy loss functions on softmaxes, this amounts to calling `layer_collection.register_softmax_cross_entropy_loss` with the logits as an argument. Note that the inputs/outputs of non-parameterized layers such as max pooling and reshaping _do not_ need to be registered.

```python
# Register parameters with graph_search.
tf.logging.info("Building KFAC Optimizer.")
layer_collection = kfac.LayerCollection()
layer_collection.register_softmax_cross_entropy_loss(logits)
# Set the layer at params0 to use a diagonal approximation
# instead of the default Kronecker-factor-based approximation.
layer_collection.define_linked_parameters(
    params0, approximation=layer_collection.APPROX_DIAGONAL_NAME)
layer_collection.auto_register_layers()
```

In the example above we demonstrate how to use a non-default Fisher approximation (diagonal) for one of the conv layers. (The default is usually Kronecker-factored.) This is done by calling `layer_collection.define_linked_parameters`, which identifies the given variables as being part of a particular layer, and sets the approximation that is to be used for that layer. Any registrations performed later, whether done by the graph scanner or performed manually by the user, will use this approximation (unless overridden by the `approx` argument to the registration function).
Layers can also be registered manually. This is required for types of layers that the automatic graph scanner doesn't recognize. Note that one can also use a combination of manual and automatic registration by calling `auto_register_layers()` after performing some manual registration. Any layers registered manually beforehand will be ignored by the scanner.

We register each layer's inputs, outputs, and parameters with an instance of `LayerCollection`. For convolution layers, we use `register_conv2d`. For fully connected (or linear) layers, `register_fully_connected`.

```python
# Register parameters manually.
tf.logging.info("Building KFAC Optimizer.")
layer_collection = kfac.LayerCollection()
layer_collection.register_softmax_cross_entropy_loss(logits)
layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples, pre0,
                                 approx=kfac_ff.APPROX_DIAGONAL_NAME)
layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
layer_collection.register_fully_connected(params4, flat_act3, logits)
```

In this example we demonstrate how to use a non-default Fisher approximation (diagonal) for one of the layers. (The default is usually Kronecker-factored.) This is done by passing `approx=kfac_ff.APPROX_DIAGONAL_NAME` to the registration function `layer_collection.register_conv2d`. Note that if one has already used `define_linked_parameters` to set the approximation, then it is not required to specify it again via the `approx` argument.

## Build the optimizer

Finally, we instantiate the optimizer. In addition to the `learning_rate` and `momentum`, the optimizer has two additional hyperparameters:

1. `cov_ema_decay`: Check the [hyper parameters][hyper_params] section for more details.
1. `damping`: This is a critical parameter and needs to be tuned for good performance. Check the [hyper parameters][hyper_params] section for more details.

[hyper_params]: https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md

```python
# Train with K-FAC.
global_step = tf.train.get_or_create_global_step()
optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
    learning_rate=0.0001,
    damping=0.001,
    momentum=0.9,
    cov_ema_decay=0.95,
    invert_every=10,
    cov_update_every=1,
    layer_collection=layer_collection)
train_op = optimizer.minimize(loss, global_step=global_step)
```

## Fit the model

Optimizing with KFAC is similar to using a standard optimizer, where there is an "update op" that computes and applies the update to the model's parameters. However, KFAC introduces two additional sets of ops that must also be executed as part of the algorithm (although not necessarily at every iteration). These are called the "covariance update ops" and "inverse update ops", respectively. The covariance update ops update the various "covariance" matrices used to compute the Fisher block approximations for the layers. The inverse update ops meanwhile are responsible for computing inverses of the approximate Fisher blocks (using algorithms that exploit their special structure).

`PeriodicInvCovUpdateKfacOpt`, which is a subclass of the `KfacOptimizer` class, folds these extra ops into the standard update op, so that they execute periodically on certain iterations, according to the `cov_update_every` and `invert_every` arguments. Users seeking more fine-grained control of the timing and placement of the ops can use the base `KfacOptimizer` class.
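As a rough illustration of the latter, here is a minimal sketch (not taken from the example code) of how the base class's thunk-based API might be wired up manually; it mirrors the pattern used in the distributed training example below and assumes `loss`, `global_step`, and `layer_collection` are defined as above.

```python
# Sketch only: manual wiring of the covariance/inverse updates with the base
# KfacOptimizer, instead of PeriodicInvCovUpdateKfacOpt.
invert_every = 10  # recompute the inverses every 10 steps

optimizer = kfac.KfacOptimizer(
    learning_rate=0.0001,
    cov_ema_decay=0.95,
    damping=0.001,
    momentum=0.9,
    layer_collection=layer_collection)

(cov_update_thunks,
 inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

# Run the covariance updates every step and the inverse updates periodically,
# before applying the parameter update.
cov_update_op = tf.group(*(thunk() for thunk in cov_update_thunks))
with tf.control_dependencies([cov_update_op]):
  inverse_op = tf.cond(
      tf.equal(tf.mod(global_step, invert_every), 0),
      lambda: tf.group(*(thunk() for thunk in inv_update_thunks)),
      tf.no_op)
with tf.control_dependencies([inverse_op]):
  train_op = optimizer.minimize(loss, global_step=global_step)
```

In either case the training loop itself just runs `train_op`: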
```python
with tf.train.MonitoredTrainingSession() as sess:
  while not sess.should_stop():
    global_step_, loss_, accuracy_, _ = sess.run(
        [global_step, loss, accuracy, train_op])
```

## TIPS

1. Check the [hyper params tuning][hp_tune] section for more details on tuning various KFAC parameters.

[hp_tune]: https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md
[mlp]: https://en.wikipedia.org/wiki/Multilayer_perceptron
[preconditioner]: https://en.wikipedia.org/wiki/Preconditioner#Preconditioning_in_optimization

================================================
FILE: docs/examples/distributed_training.md
================================================

# Distributed Training

## Table of Contents

* [Build the Model](#build-the-model)
* [Register the layers](#register-the-layers)
* [Build the optimizer](#build-the-optimizer)
* [Fit the model](#fit-the-model)
* [TIPS](#tips)
This example showcases how to use K-FAC in a distributed setting with the `SyncReplicas` optimizer. If you are interested in using `tf.distribute.Strategy`, we support `MirroredStrategy` and `TPUStrategy`, with an example for `TPUStrategy` [here][tpu_strategy_example]. While most methods benefit from increased compute, K-FAC particularly shines as the number of workers (and, in turn, batch size) increases.

[tpu_strategy_example]: https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb

**Note:** This tutorial extends the single-machine [Convolutional example][conv_ex] to distributed training. It is highly recommended you read that first, as shared bits are omitted below!

[conv_ex]: https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md

**Note:** This tutorial expects you to be familiar with distributed training. Check out https://www.tensorflow.org/deploy/distributed if this is new to you.

**Example code**: https://github.com/tensorflow/kfac/tree/master/kfac/examples/convnet_mnist_distributed_main.py

## Build the Model

When training on a single machine, one doesn't need to think about which "device" a variable is placed on (there's only 1 to choose from!). In a distributed setting, variables live on ["Parameter Servers"][parameter-servers]. Placing a variable on a parameter server is as simple as using `tf.train.replica_device_setter()`, as illustrated in the code below.

```python
with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
  pre0, act0, w0, b0 = conv_layer(
      layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
  act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
  ...
```

[parameter-servers]: https://www.tensorflow.org/deploy/distributed

## Register the layers

Layer registration is identical to the single-machine case. See ["Register the layers"][register-layers-conv] in the Convolutional example for details.

[register-layers-conv]: https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md?#register-the-layers-and-loss

## Build the optimizer

Like the model itself, the K-FAC optimizer also creates variables. Don't forget to wrap it in a similar `replica_device_setter()` too!

```python
with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
  ...
  optimizer = opt.KfacOptimizer(
      learning_rate=0.0001,
      cov_ema_decay=0.95,
      damping=0.001,
      layer_collection=layer_collection,
      momentum=0.9)
  ...
```

## Fit the model

When training on a single machine, a single training loop is responsible for executing all of K-FAC's training operations: updating weights, updating statistics, and inverting the preconditioner matrix. As all of the work happens on a single machine, one stands to gain little from parallelization.

There are different strategies for parallelizing the gradient, covariance, and inverse computations across workers in a distributed setting. We will illustrate here two such strategies that work specifically with the `SyncReplicas` optimizer for distributed training.

The first strategy is to compute the gradient in a distributed fashion across all the workers, but have the inverse and covariance ops executed only on the chief worker.

**Code**: https://github.com/tensorflow/kfac/tree/master/kfac/examples/convnet.py

```python
optimizer = opt.KfacOptimizer(...)
sync_optimizer = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
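# Only the chief worker folds the covariance and inverse update ops into its
# train op (via the control dependencies below); the other workers simply run
# the SyncReplicas-wrapped minimize().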
(cov_update_thunks, inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() tf.logging.info("Starting training.") hooks = [sync_optimizer.make_session_run_hook(is_chief)] def make_update_op(update_thunks): update_ops = [thunk() for thunk in update_thunks] return tf.group(*update_ops) if is_chief: cov_update_op = make_update_op(cov_update_thunks) with tf.control_dependencies([cov_update_op]): inverse_op = tf.cond( tf.equal(tf.mod(global_step, invert_every), 0), lambda: make_update_op(inv_update_thunks), tf.no_op) with tf.control_dependencies([inverse_op]): train_op = sync_optimizer.minimize(loss, global_step=global_step) else: train_op = sync_optimizer.minimize(loss, global_step=global_step) ``` In the second strategy, each worker's training loop is responsible for executing only one of K-FAC's three training ops, 1. Compute gradients. 1. Workers updating covariance matrices can asynchronously update the moving average similar to the way asynchronous SGD updates weights. 1. Workers inverting the preconditioning matrix can independently and asynchronously invert its blocks, one at a time. Blocks are chosen according to a randomly shuffled queue. ```python optimizer = opt.KfacOptimizer(...) inv_update_queue = oq.OpQueue(optimizer.inv_updates_dict.values()) sync_optimizer = tf.train.SyncReplicasOptimizer( opt=optimizer, replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks)) train_op = sync_optimizer.minimize(loss, global_step=global_step) with tf.train.MonitoredTrainingSession(...) as sess: while not sess.should_stop(): if _is_gradient_task(task_id, num_worker_tasks): learning_op = train_op elif _is_cov_update_task(task_id, num_worker_tasks): learning_op = optimizer.cov_update_op elif _is_inv_update_task(task_id, num_worker_tasks): learning_op = inv_update_queue.next_op(sess) global_step_, loss_, statistics_, _ = sess.run( [global_step, loss, statistics, learning_op]) ``` ## TIPS 1. Check the [hyper params tuning][hp_tune] section for more details on tuning various KFAC parameters. [hp_tune]: https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md ================================================ FILE: docs/examples/parameters.md ================================================ # K-FAC Parameters. ## Table of Contents * [Damping](#damping) * [Learning Rate](#learning-rate) * [Subsample covariance computation](#subsample-covariance-computation) * [KFAC norm constraint](#kfac-norm-constraint) * [Covariance decay](#covariance-decay) * [Train batch size](#train-batch-size)
We list below various parameters which can be tuned to improve the training and run-time performance of K-FAC.

## Damping

Damping is a crucial aspect of K-FAC, as it is for any second-order optimization/natural gradient method. Broadly speaking, it refers to the practice of penalizing or constraining the size of the update in various ways so that it doesn't leave the local region where the quadratic approximation to the objective (which is used to compute the update) remains accurate. This region is commonly referred to as the "trust region". In some literature damping is called "regularization", although we will avoid that term due to its related but distinct meaning as a method to combat overfitting.

The damping strategy used in KFAC is to (approximately) add a multiple of the identity to the Fisher before inverting it. This is essentially equivalent to enforcing that the update lie in a spherical trust region centered at the current location in parameter space. The `damping` parameter represents the multiple of the identity which is used. Higher values correspond to smaller trust regions, although the precise relationship between `damping` and the size of the trust region depends on the scale of the objective, and will vary from iteration to iteration. (If the loss function is multiplied by a scalar 'alpha' then damping should be multiplied by 'alpha' as well.) Higher values of `damping` can allow higher learning rates, but as damping tends to infinity the KFAC updates will start to resemble regular gradient descent updates (scaled by `1/damping`).

The `damping` parameter depends on the scale of the loss function. `damping` is a critical parameter that needs to be tuned. Options for tuning include a grid sweep (which must be done simultaneously with learning rate optimization - NOT independently), or auto-tuning using the Levenberg-Marquardt (LM) algorithm (see the [`Auto Damping`][auto_damping] section for further details). For grid sweeps a typical range to consider would be logarithmically spaced values between `1e-5` and `100`, although the optimal value could be any non-negative real number in principle (because the scale of the loss is arbitrary). Another option for tuning `damping` is [`Population based training`][PBT] (PBT). Refer to Section 6 of the [KFAC paper][kfac_paper] for a more detailed discussion of damping and how it can be used/tuned in KFAC.

[auto_damping]: https://github.com/tensorflow/kfac/tree/master/docs/examples/auto_damp.md
[PBT]: https://arxiv.org/abs/1711.09846
[kfac_paper]: https://arxiv.org/pdf/1503.05671.pdf

## Learning Rate

Typically sweep over values in the range 1e-5 to 100. It is important to tune the learning rate in conjunction with damping, since the two are closely coupled (higher damping allows higher learning rates). The learning rate can also be tuned using PBT. Note that the optimal learning rate will generally be different from the learning rate used for the SGD/RMSProp/Adam optimizers.

## Subsample covariance computation

If you are using conv layers and observe that the KFAC iterations are significantly slower than Adam, or if you run out of memory, then a possible remedy is to use subsampling in the covariance computation. To turn on subsampling set `kfac_ff.sub_sample_inputs` to `True` and `kfac_ff.sub_sample_outer_products` to `True`. The former flag subsamples the batch of inputs used for covariance computation, and the latter flag subsamples the extracted patches based on the size of the covariance matrix.
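For example, flipping these switches might look roughly like the following (a hedged sketch: it assumes the `fisher_factors` module exposes a `set_global_constants` helper accepting keyword switches matching the flag names above, which should be verified against the module itself).

```python
import kfac

# Assumption: fisher_factors.set_global_constants accepts keyword arguments
# matching the flags named in the text above; check the module documentation
# for the exact names before relying on this.
kfac.fisher_factors.set_global_constants(
    sub_sample_inputs=True,          # subsample the batch of inputs
    sub_sample_outer_products=True)  # subsample the extracted patches
```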
Check the documentation of `tensorflow_kfac.fisher_factors` for a detailed explanation of the various subsampling parameters. Also check the [`Distributed training`][dist_train] section for how to distribute the computation of these ops over multiple devices.

[dist_train]: https://github.com/tensorflow/kfac/tree/master/docs/examples/distributed_training.md

## KFAC norm constraint

Scales the K-FAC update so that its approximate Fisher norm is bounded. Typically use an initial value of 1.0 and tune it using PBT or perform a grid search. The norm constraint can be used as an alternative to learning rate schedules. See Section 5 of the [Distributed Second-Order Optimization using Kronecker-Factored Approximations][ba_paper] paper for further details.

[ba_paper]: https://jimmylba.github.io/papers/nsync.pdf

## Covariance decay

During the course of the algorithm, an exponential moving average tracks statistics for each layer. Slower decays mean that the statistics are based on more data, but will suffer more from the issue of staleness (because of the changing model parameters). This parameter can usually be left at its default value but may occasionally matter for some problems. In such cases some reasonable values to sweep over are `[0.9, 0.95, 0.99, 0.999]`.

## Train batch size

Typically try using a larger batch size compared to training with SGD/RMSprop/Adam.

================================================
FILE: docs/index.md
================================================

# Home

Kronecker factored approximate curvature

**K-FAC in TensorFlow** is an implementation of K-FAC, an approximate second-order optimization method, in TensorFlow. K-FAC can converge much faster than SGD or Adam on certain neural network architectures (especially when using larger batch sizes), but may be closer in performance on other architectures (such as ResNets).

## Table of Contents

* [What is K-FAC?](#what-is-k-fac)
* [Why should I use K-FAC?](#why-should-i-use-k-fac)
* [How do I use K-FAC?](#how-do-i-use-k-fac)

## What is K-FAC?

K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation to the [Natural Gradient][natural_gradient] algorithm designed specifically for neural networks. It maintains an approximation to the [Fisher Information matrix][fisher_information], whose inverse is used as a preconditioner for (stochastic) gradient descent.

K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations. However, it is slightly more restrictive than SGD or Adam, as it makes some assumptions about the structure of the model and the loss function. Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What are the weights for layer i?"). As such, you must add some additional code while constructing your model to use K-FAC.

[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746
[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form

## Why should I use K-FAC?

K-FAC can take advantage of the curvature of the optimization problem, resulting in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. See reference code [here][autoencoder-code] and plots comparing KFAC with SGD below.

![](https://github.com/tensorflow/kfac/tree/master/kfac/g3doc/sgd_comparison.png?raw=true)

[autoencoder-code]: https://github.com/tensorflow/kfac/tree/master/kfac/examples/autoencoder_mnist.py

## How do I use K-FAC?
Using K-FAC requires three steps:

1. Register layer inputs, weights, and pre-activations with a `kfac.LayerCollection`.
2. Register the loss function.
3. Minimize the loss with a `kfac.PeriodicInvCovUpdateKfacOpt`.

```python
import kfac

# Build model.
w = tf.get_variable("w", ...)
b = tf.get_variable("b", ...)
logits = tf.matmul(x, w) + b
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Register loss.
layer_collection = kfac.LayerCollection()
layer_collection.register_softmax_cross_entropy_loss(logits)

# Register layers.
layer_collection.auto_register_layers()

# Construct training ops.
optimizer = kfac.PeriodicInvCovUpdateKfacOpt(..., layer_collection=layer_collection)
train_op = optimizer.minimize(loss)

# Minimize loss.
with tf.Session() as sess:
  ...
  sess.run([train_op])
```

Check out the Convnet training [example][convexamplesec] for more details. Also check the [`PeriodicInvCovUpdateKfacOpt`][periodicincovupdate] optimizer to see how the placement and execution of the covariance and inverse ops can be handled automatically.

[convexamplesec]: https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md
[periodicincovupdate]: https://github.com/tensorflow/kfac/tree/master/kfac/python/ops/kfac_utils/periodic_inv_cov_update_kfac_opt.py

## Table of contents

* [Home](https://github.com/tensorflow/kfac/tree/master/docs/index.md)
* User Guide
* [Keras](https://github.com/tensorflow/kfac/tree/master/kfac/python/keras/README.md)
* [Convolutional](https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md)
* [Auto damping](https://github.com/tensorflow/kfac/tree/master/docs/examples/auto_damp.md)
* [Distributed Training](https://github.com/tensorflow/kfac/tree/master/docs/examples/distributed_training.md)
* [Parameters](https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md)
* [Applications](https://github.com/tensorflow/kfac/tree/master/docs/applications.md)
* [Some KFAC-Papers](https://github.com/tensorflow/kfac/tree/master/docs/papers.md)
* [Contact](https://github.com/tensorflow/kfac/tree/master/docs/contact.md)

================================================
FILE: docs/papers.md
================================================

* Martens, James, and Roger Grosse. ["Optimizing neural networks with Kronecker-factored approximate curvature."][kfac_paper] International Conference on Machine Learning. 2015.
* Grosse, Roger, and James Martens. ["A Kronecker-factored approximate Fisher matrix for convolution layers."][kfac_conv_paper] International Conference on Machine Learning. 2016.
* Ba, Jimmy, Roger Grosse, and James Martens. ["Distributed Second-Order Optimization using Kronecker-Factored Approximations."][distributed_kfac] (2016).
* Martens, James, Jimmy Ba, and Matt Johnson. ["Kronecker-factored Curvature Approximations for Recurrent Neural Networks."][kfac_rnn_paper] ICLR. 2018.
[kfac_paper]: https://arxiv.org/abs/1503.05671 [kfac_conv_paper]: https://arxiv.org/abs/1602.01407 [kfac_rnn_paper]: https://openreview.net/forum?id=HyMTkQZAb [distributed_kfac]: https://openreview.net/forum?id=SkkTMpjex ================================================ FILE: docs/sitemap.md ================================================ * [Home](https://github.com/tensorflow/kfac/tree/master/docs/index.md) * User Guide * [Convolutional](https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md) * [Auto damping](https://github.com/tensorflow/kfac/tree/master/docs/examples/auto_damp.md) * [Distributed Training](https://github.com/tensorflow/kfac/tree/master/docs/examples/distributed_training.md) * [Parameters](https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md) * [Applications](https://github.com/tensorflow/kfac/tree/master/docs/applications.md) * [Some KFAC-Papers](https://github.com/tensorflow/kfac/tree/master/docs/papers.md) * [Contact](https://github.com/tensorflow/kfac/tree/master/docs/contact.md) ================================================ FILE: kfac/__init__.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Kronecker-factored Approximate Curvature Optimizer.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long from kfac.python import keras from kfac.python.ops import curvature_matrix_vector_products from kfac.python.ops import estimator from kfac.python.ops import fisher_blocks from kfac.python.ops import fisher_factors from kfac.python.ops import layer_collection from kfac.python.ops import linear_operator from kfac.python.ops import loss_functions from kfac.python.ops import op_queue from kfac.python.ops import optimizer from kfac.python.ops import placement from kfac.python.ops import utils from kfac.python.ops.kfac_utils import async_inv_cov_update_kfac_opt from kfac.python.ops.kfac_utils import data_reader from kfac.python.ops.kfac_utils import data_reader_alt from kfac.python.ops.kfac_utils import periodic_inv_cov_update_kfac_opt from kfac.python.ops.tensormatch import graph_matcher from kfac.python.ops.tensormatch import graph_search # pylint: enable=unused-import # pylint: disable=invalid-name LayerCollection = layer_collection.LayerCollection KfacOptimizer = optimizer.KfacOptimizer PeriodicInvCovUpdateKfacOpt = periodic_inv_cov_update_kfac_opt.PeriodicInvCovUpdateKfacOpt AsyncInvCovUpdateKfacOpt = async_inv_cov_update_kfac_opt.AsyncInvCovUpdateKfacOpt CurvatureMatrixVectorProductComputer = curvature_matrix_vector_products.CurvatureMatrixVectorProductComputer # pylint: enable=invalid-name, line-too-long ================================================ FILE: kfac/examples/__init__.py ================================================ ================================================ FILE: kfac/examples/autoencoder_mnist.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Full implementation of deep autoencoder experiment from original K-FAC paper. This script demonstrates training using KFAC optimizer, updating the damping parameter according to the Levenberg-Marquardt rule, and using the quadratic model method for adapting the learning rate and momentum parameters. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import math # Dependency imports from absl import flags import kfac import sonnet as snt import tensorflow.compat.v1 as tf from kfac.examples import mnist from kfac.python.ops.kfac_utils import data_reader from kfac.python.ops.kfac_utils import data_reader_alt # Model parameters _ENCODER_SIZES = [1000, 500, 250, 30] _DECODER_SIZES = [250, 500, 1000] _NONLINEARITY = tf.tanh # Note: sigmoid cannot be used with the default init. 
_WEIGHTS_INITIALIZER = None # Default init flags.DEFINE_integer('train_steps', 10000, 'Number of training steps.') flags.DEFINE_integer('inverse_update_period', 5, '# of steps between computing inverse of Fisher factor ' 'matrices.') flags.DEFINE_integer('cov_update_period', 1, '# of steps between computing covaraiance matrices.') flags.DEFINE_integer('damping_adaptation_interval', 5, '# of steps between updating the damping parameter.') flags.DEFINE_integer('num_burnin_steps', 5, 'Number of steps at the ' 'start of training where the optimizer will only perform ' 'cov updates.') flags.DEFINE_integer('seed', 12345, 'Random seed') flags.DEFINE_float('learning_rate', 3e-3, 'Learning rate to use when lrmu_adaptation="off".') flags.DEFINE_float('momentum', 0.9, 'Momentum decay value to use when ' 'lrmu_adaptation="off" or "only_lr".') flags.DEFINE_float('damping', 1e-2, 'The fixed damping value to use. This is ' 'ignored if adapt_damping is True.') flags.DEFINE_float('l2_reg', 1e-5, 'L2 regularization applied to weight matrices.') flags.DEFINE_boolean('update_damping_immediately', True, 'Adapt the damping ' 'immediately after the parameter update (i.e. in the same ' 'sess.run() call). Only safe if everything is a resource ' 'variable.') flags.DEFINE_boolean('use_batch_size_schedule', True, 'If True then we use the growing mini-batch schedule from ' 'the original K-FAC paper.') flags.DEFINE_integer('batch_size', 1024, 'The size of the mini-batches to use if not using the ' 'schedule.') flags.DEFINE_string('lrmu_adaptation', 'on', 'If set to "on" then we use the quadratic model ' 'based learning-rate and momentum adaptation method from ' 'the original paper. Note that this only works well in ' 'practice when use_batch_size_schedule=True. Can also ' 'be set to "off" and "only_lr", which turns ' 'it off, or uses a version where the momentum parameter ' 'is fixed (resp.).') flags.DEFINE_boolean('use_alt_data_reader', True, 'If True we use the alternative data reader for MNIST ' 'that is faster for small datasets.') flags.DEFINE_string('device', '/gpu:0', 'The device to run the major ops on.') flags.DEFINE_boolean('adapt_damping', True, 'If True we use the LM rule for damping adaptation as ' 'described in the original K-FAC paper.') # When using damping adaptation it is advisable to start with a high # value. This value is probably far too high to use for most neural nets # if you aren't using damping adaptation. (Although it always depends on # the scale of the loss.) flags.DEFINE_float('initial_damping', 150.0, 'The initial damping value to use when adapt_damping is ' 'True.') flags.DEFINE_string('optimizer', 'kfac', 'The optimizer to use. Can be kfac or adam. If adam is ' 'used the various K-FAC hyperparameter map roughly on to ' 'their Adam equivalents.') flags.DEFINE_boolean('auto_register_layers', True, 'If True we use the automatic registration feature ' 'which relies on scanning the TF graph. Otherwise ' 'registration is done manually by this script during ' 'the construction of the model.') flags.DEFINE_boolean('use_keras_model', False, 'If True, we use a Keras version of the autoencoder ' 'model. Only works when auto_register_layers=True.') flags.DEFINE_boolean('use_sequential_for_keras', True, 'If True, we construct the Keras model using the ' 'Sequential class.') flags.DEFINE_boolean('use_control_flow_v2', False, 'If True, we use Control ' 'Flow V2. 
Defaults to False.') FLAGS = flags.FLAGS def make_train_op(minibatch, batch_size, batch_loss, layer_collection, loss_fn, prev_train_batch=None, placement_strategy=None, print_logs=False, tf_replicator=None): """Constructs optimizer and train op. Args: minibatch: A list/tuple of Tensors (typically representing the current mini-batch of input images and labels). batch_size: Tensor of shape (). Size of the training mini-batch. batch_loss: Tensor of shape (). Mini-batch loss tensor. layer_collection: LayerCollection object. Registry for model parameters. Required when using a K-FAC optimizer. loss_fn: Function which takes as input a mini-batch and returns the loss. prev_train_batch: `Tensor` of the previous training batch, can be accessed from the data_reader.CachedReader cached_batch property. (Default: None) placement_strategy: `str`, the placement_strategy argument for `KfacOptimizer`. (Default: None) print_logs: `Bool`. If True we print logs using K-FAC's built-in tf.print-based logs printer. (Default: False) tf_replicator: A Replicator object or None. If not None, K-FAC will set itself up to work inside of the provided TF-Replicator object. (Default: None) Returns: train_op: Op that can be used to update model parameters. optimizer: Optimizer used to produce train_op. Raises: ValueError: If layer_collection is None when K-FAC is selected as an optimization method. """ global_step = tf.train.get_or_create_global_step() if FLAGS.optimizer == 'kfac': if FLAGS.lrmu_adaptation == 'on': learning_rate = None momentum = None momentum_type = 'qmodel' elif FLAGS.lrmu_adaptation == 'only_lr': learning_rate = None momentum = FLAGS.momentum momentum_type = 'qmodel_fixedmu' elif FLAGS.lrmu_adaptation == 'off': learning_rate = FLAGS.learning_rate momentum = FLAGS.momentum # momentum_type = 'regular' momentum_type = 'adam' if FLAGS.adapt_damping: damping = FLAGS.initial_damping else: damping = FLAGS.damping optimizer = kfac.PeriodicInvCovUpdateKfacOpt( invert_every=FLAGS.inverse_update_period, cov_update_every=FLAGS.cov_update_period, learning_rate=learning_rate, damping=damping, cov_ema_decay=0.95, momentum=momentum, momentum_type=momentum_type, layer_collection=layer_collection, batch_size=batch_size, num_burnin_steps=FLAGS.num_burnin_steps, adapt_damping=FLAGS.adapt_damping, l2_reg=FLAGS.l2_reg, placement_strategy=placement_strategy, print_logs=print_logs, tf_replicator=tf_replicator, # Note that many of the arguments below don't do anything when # adapt_damping=False. update_damping_immediately=FLAGS.update_damping_immediately, is_chief=True, prev_train_batch=prev_train_batch, # We don't actually need this unless # update_damping_immediately is # False. 
loss=batch_loss, loss_fn=loss_fn, damping_adaptation_decay=0.95, damping_adaptation_interval=FLAGS.damping_adaptation_interval, min_damping=1e-6, train_batch=minibatch, ) elif FLAGS.optimizer == 'adam': optimizer = tf.train.AdamOptimizer( learning_rate=FLAGS.learning_rate, beta1=FLAGS.momentum, epsilon=FLAGS.damping, beta2=0.99) return optimizer.minimize(batch_loss, global_step=global_step), optimizer class AutoEncoder(snt.AbstractModule): """Simple autoencoder module.""" def __init__(self, input_size, regularizers=None, initializers=None, custom_getter=None, name='AutoEncoder'): super(AutoEncoder, self).__init__(custom_getter=custom_getter, name=name) if initializers is None: initializers = {'w': tf.glorot_uniform_initializer(), 'b': tf.zeros_initializer()} if regularizers is None: regularizers = {'w': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w), 'b': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),} with self._enter_variable_scope(): self._encoder = snt.nets.MLP( output_sizes=_ENCODER_SIZES, regularizers=regularizers, initializers=initializers, custom_getter=custom_getter, activation=_NONLINEARITY, activate_final=False) self._decoder = snt.nets.MLP( output_sizes=_DECODER_SIZES + [input_size], regularizers=regularizers, initializers=initializers, custom_getter=custom_getter, activation=_NONLINEARITY, activate_final=False) def _build(self, inputs): code = self._encoder(inputs) output = self._decoder(code) return output class MLPManualReg(snt.AbstractModule): def __init__(self, output_sizes, regularizers=None, initializers=None, custom_getter=None, activation=_NONLINEARITY, activate_final=False, name='MLP'): super(MLPManualReg, self).__init__(custom_getter=custom_getter, name=name) self._output_sizes = output_sizes self._activation = activation self._activate_final = activate_final with self._enter_variable_scope(): self._layers = [snt.Linear(self._output_sizes[i], name='linear_{}'.format(i), initializers=initializers, regularizers=regularizers, custom_getter=custom_getter, use_bias=True) for i in range(len(self._output_sizes))] def _build(self, inputs, layer_collection=None): net = inputs for i in range(len(self._output_sizes)): layer_inputs = net net = self._layers[i](net) layer_outputs = net params = (self._layers[i].w, self._layers[i].b) if layer_collection is not None: layer_collection.register_fully_connected(params, layer_inputs, layer_outputs, reuse=False) if i < len(self._output_sizes) - 1 or self._activate_final: net = self._activation(net) return net class AutoEncoderManualReg(snt.AbstractModule): """Simple autoencoder module.""" def __init__(self, input_size, regularizers=None, initializers=None, custom_getter=None, name='AutoEncoder'): super(AutoEncoderManualReg, self).__init__(custom_getter=custom_getter, name=name) if initializers is None: initializers = {'w': tf.glorot_uniform_initializer(), 'b': tf.zeros_initializer()} if regularizers is None: regularizers = {'w': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w), 'b': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),} with self._enter_variable_scope(): self._encoder = MLPManualReg( output_sizes=_ENCODER_SIZES, regularizers=regularizers, initializers=initializers, custom_getter=custom_getter, activation=_NONLINEARITY, activate_final=False) self._decoder = MLPManualReg( output_sizes=_DECODER_SIZES + [input_size], regularizers=regularizers, initializers=initializers, custom_getter=custom_getter, activation=_NONLINEARITY, activate_final=False) def _build(self, inputs, layer_collection=None): code = self._encoder(inputs, layer_collection=layer_collection) 
output = self._decoder(code, layer_collection=layer_collection) return output def get_keras_autoencoder(**input_kwargs): """Returns autoencoder made with Keras. Args: **input_kwargs: Arguments to pass to tf.keras.layers.Input. You must include either the 'shape' or 'tensor' kwarg. Returns: A tf.keras.Model, the Autoencoder. """ layers = tf.keras.layers regularizers = tf.keras.regularizers dense_kwargs = { 'kernel_initializer': tf.glorot_uniform_initializer(), 'bias_initializer': tf.zeros_initializer(), 'kernel_regularizer': regularizers.l2(l=FLAGS.l2_reg), 'bias_regularizer': regularizers.l2(l=FLAGS.l2_reg), } if FLAGS.use_sequential_for_keras: model = tf.keras.Sequential() # Create Encoder model.add(layers.Input(**input_kwargs)) for size in _ENCODER_SIZES[:-1]: model.add(layers.Dense( size, activation=_NONLINEARITY, **dense_kwargs)) model.add(layers.Dense(_ENCODER_SIZES[-1], **dense_kwargs)) # Create Decoder for size in _DECODER_SIZES: model.add(layers.Dense(size, activation=_NONLINEARITY, **dense_kwargs)) model.add(layers.Dense(784, **dense_kwargs)) else: # Make sure you always wrap the input in keras inputs = layers.Input(**input_kwargs) x = inputs # Create Encoder for size in _ENCODER_SIZES[:-1]: x = layers.Dense(size, activation=_NONLINEARITY, **dense_kwargs)(x) x = layers.Dense(_ENCODER_SIZES[-1], **dense_kwargs)(x) # Create Decoder for size in _DECODER_SIZES: x = layers.Dense(size, activation=_NONLINEARITY, **dense_kwargs)(x) x = layers.Dense(784, **dense_kwargs)(x) model = tf.keras.Model(inputs=inputs, outputs=x) return model def compute_squared_error(logits, targets): """Compute mean squared error.""" return tf.reduce_sum( tf.reduce_mean(tf.square(targets - tf.nn.sigmoid(logits)), axis=0)) def compute_loss(logits=None, labels=None, return_error=False, model=None): """Compute loss value.""" if FLAGS.use_keras_model: total_regularization_loss = tf.add_n(model.losses) else: graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) total_regularization_loss = tf.add_n(graph_regularizers) loss_matrix = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels) loss = tf.reduce_sum(tf.reduce_mean(loss_matrix, axis=0)) regularized_loss = loss + total_regularization_loss if return_error: squared_error = compute_squared_error(logits, labels) return regularized_loss, squared_error return regularized_loss def load_mnist(): """Creates MNIST dataset and wraps it inside cached data reader. Returns: cached_reader: `data_reader.CachedReader` instance which wraps MNIST dataset. num_examples: int. The number of training examples. """ # Wrap the data set into cached_reader which provides variable sized training # and caches the read train batch. if not FLAGS.use_alt_data_reader: # Version 1 using data_reader.py (slow!) dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True) if FLAGS.use_batch_size_schedule: max_batch_size = num_examples else: max_batch_size = FLAGS.batch_size # Shuffle before repeat is correct unless you want repeat cases in the # same batch. 
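    # For reference (an illustrative sketch, not part of the original
    # pipeline), the two orderings behave differently:
    #
    #   dataset.shuffle(num_examples).repeat()  # reshuffles each epoch; an
    #                                           # example can't repeat within
    #                                           # an epoch (or batch).
    #   dataset.repeat().shuffle(num_examples)  # mixes across epoch
    #                                           # boundaries, so an example
    #                                           # can appear twice in a batch.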
dataset = (dataset.shuffle(num_examples).repeat() .batch(max_batch_size).prefetch(5)) dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() # This version of CachedDataReader requires the dataset to be shuffled return data_reader.CachedDataReader(dataset, max_batch_size), num_examples else: # Version 2 using data_reader_alt.py (faster) images, labels, num_examples = mnist.load_mnist_as_tensors( flatten_images=True) dataset = (images, labels) # This version of CachedDataReader requires the dataset to NOT be shuffled return data_reader_alt.CachedDataReader(dataset, num_examples), num_examples def _get_batch_size_schedule(minibatch_maxsize): """Returns training batch size schedule.""" minibatch_maxsize_targetiter = 500 minibatch_startsize = 1000 div = (float(minibatch_maxsize_targetiter-1) / math.log(float(minibatch_maxsize)/minibatch_startsize, 2)) return [ min(int(2.**(float(k)/div) * minibatch_startsize), minibatch_maxsize) for k in range(minibatch_maxsize_targetiter) ] def construct_train_quants(): """Returns tensors and optimizer required to run the autoencoder.""" with tf.device(FLAGS.device): # Load dataset. cached_reader, num_examples = load_mnist() batch_size_schedule = _get_batch_size_schedule(num_examples) batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size') minibatch = cached_reader(batch_size) features, _ = minibatch if FLAGS.auto_register_layers: if FLAGS.use_keras_model: training_model = get_keras_autoencoder(tensor=features) else: training_model = AutoEncoder(784) else: training_model = AutoEncoderManualReg(784) layer_collection = kfac.LayerCollection() def loss_fn(minibatch, logits=None, return_error=False): features, _ = minibatch if logits is None: logits = training_model(features) return compute_loss( logits=logits, labels=features, return_error=return_error, model=training_model) if FLAGS.use_keras_model: logits = training_model.output else: if FLAGS.auto_register_layers: logits = training_model(features) else: logits = training_model(features, layer_collection=layer_collection) (batch_loss, batch_error) = loss_fn( minibatch, logits=logits, return_error=True) # Make sure never to confuse this with register_softmax_cross_entropy_loss! layer_collection.register_sigmoid_cross_entropy_loss(logits, seed=FLAGS.seed + 1) if FLAGS.auto_register_layers: layer_collection.auto_register_layers() # Make training op train_op, opt = make_train_op( minibatch, batch_size, batch_loss, layer_collection, loss_fn=loss_fn, prev_train_batch=cached_reader.cached_batch) return train_op, opt, batch_loss, batch_error, batch_size_schedule, batch_size def main(_): # If using update_damping_immediately resource variables must be enabled. # Would recommend always enabling them anyway. if FLAGS.update_damping_immediately: tf.enable_resource_variables() if FLAGS.use_control_flow_v2: tf.enable_control_flow_v2() if not FLAGS.auto_register_layers and FLAGS.use_keras_model: raise ValueError('Require auto_register_layers=True when using Keras ' 'model.') tf.set_random_seed(FLAGS.seed) (train_op, opt, batch_loss, batch_error, batch_size_schedule, batch_size) = construct_train_quants() global_step = tf.train.get_or_create_global_step() if FLAGS.optimizer == 'kfac': # We need to put the control depenency on train_op here so that we are # guaranteed to get the up-to-date values of these various quantities. # Otherwise there is a race condition and we might get the old values, # nondeterministically. 
Another solution would be to get these values in # a separate sess.run call, but this can sometimes cause problems with # training frameworks that use hooks (see the comments below). with tf.control_dependencies([train_op]): learning_rate = opt.learning_rate momentum = opt.momentum damping = opt.damping rho = opt.rho qmodel_change = opt.qmodel_change # Without setting allow_soft_placement=True there will be problems when # the optimizer tries to place certain ops like "mod" on the GPU (which isn't # supported). config = tf.ConfigProto(allow_soft_placement=True) # It's good practice to put everything into a single sess.run call. The # reason is that certain "training frameworks" like to run hooks at each # sess.run call, and there is an implicit expectation there will only # be one sess.run call every "iteration" of the "optimizer". For example, # a framework might try to print the loss at each sess.run call, causing # the mini-batch to be advanced, thus completely breaking the "cached # batch" mechanism that the damping adaptation method may rely on. (Plus # there will also be the extra cost of having to reevaluate the loss # twice.) That being said we don't completely do that here because it's # inconvenient. # Train model. with tf.train.MonitoredTrainingSession(save_checkpoint_secs=30, config=config) as sess: for _ in range(FLAGS.train_steps): i = sess.run(global_step) if FLAGS.use_batch_size_schedule: batch_size_ = batch_size_schedule[min(i, len(batch_size_schedule) - 1)] else: batch_size_ = FLAGS.batch_size if FLAGS.optimizer == 'kfac': (_, batch_loss_, batch_error_, learning_rate_, momentum_, damping_, rho_, qmodel_change_) = sess.run([train_op, batch_loss, batch_error, learning_rate, momentum, damping, rho, qmodel_change], feed_dict={batch_size: batch_size_}) else: _, batch_loss_, batch_error_ = sess.run( [train_op, batch_loss, batch_error], feed_dict={batch_size: batch_size_}) # Print training stats. tf.logging.info( 'iteration: %d', i) tf.logging.info( 'mini-batch size: %d | mini-batch loss = %f | mini-batch error = %f ', batch_size_, batch_loss_, batch_error_) if FLAGS.optimizer == 'kfac': tf.logging.info( 'learning_rate = %f | momentum = %f', learning_rate_, momentum_) tf.logging.info( 'damping = %f | rho = %f | qmodel_change = %f', damping_, rho_, qmodel_change_) tf.logging.info('----') if __name__ == '__main__': tf.disable_v2_behavior() tf.app.run(main) ================================================ FILE: kfac/examples/autoencoder_mnist_tpu_estimator.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Implementation of Deep AutoEncoder from Martens & Grosse (2015). 
This script demonstrates training on TPUs with TPU Estimator using the KFAC
optimizer, updating the damping parameter according to the Levenberg-Marquardt
rule, and using the quadratic model method for adapting the learning rate and
momentum parameters.

See third_party/tensorflow_kfac/google/examples/ae_tpu_xm_launcher.py for an
example Borg launch script.

If you can't access this launch script, some important things to know about
running K-FAC on TPUs (at least for this example) are that you must use
higher-precision matrix multiplications.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
from absl import flags
import kfac
import tensorflow.compat.v1 as tf

from tensorflow.contrib import tpu as contrib_tpu
from kfac.examples import autoencoder_mnist
from kfac.examples import mnist


flags.DEFINE_integer('save_checkpoints_steps', 500,
                     'Number of iterations between model checkpoints.')
flags.DEFINE_integer('iterations_per_loop', 100,
                     'Number of iterations in a TPU training loop.')
flags.DEFINE_string('model_dir', '', 'Model dir.')
flags.DEFINE_string('master', None,
                    'GRPC URL of the master '
                    '(e.g. grpc://ip.address.of.tpu:8470).')

FLAGS = flags.FLAGS


def make_train_op(minibatch, batch_loss, layer_collection, loss_fn):
  """Constructs optimizer and train op.

  Args:
    minibatch: Tuple[Tensor, Tensor] representing the current batch of input
      images and labels.
    batch_loss: Tensor of shape (), Loss with respect to minibatch to be
      minimized.
    layer_collection: LayerCollection object. Registry for model parameters.
      Required when using a K-FAC optimizer.
    loss_fn: A function that when called constructs the graph to compute the
      model loss on the current minibatch. Returns a Tensor of the loss scalar.

  Returns:
    train_op: Op that can be used to update model parameters.
    optimizer: The KFAC optimizer used to produce train_op.

  Raises:
    ValueError: If layer_collection is None when K-FAC is selected as an
      optimization method.
  """
  # Do not use CrossShardOptimizer with K-FAC. K-FAC now handles its own
  # cross-replica synchronization automatically!
  return autoencoder_mnist.make_train_op(
      minibatch=minibatch,
      batch_size=minibatch[0].get_shape().as_list()[0],
      batch_loss=batch_loss,
      layer_collection=layer_collection,
      loss_fn=loss_fn,
      prev_train_batch=None,
      placement_strategy='replica_round_robin',
  )


def compute_squared_error(logits, targets):
  """Compute mean squared error."""
  return tf.reduce_sum(
      tf.reduce_mean(tf.square(targets - tf.nn.sigmoid(logits)), axis=0))


def compute_loss(logits, labels):
  """Compute loss value."""
  graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  total_regularization_loss = tf.add_n(graph_regularizers)
  loss_matrix = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                        labels=labels)
  loss = tf.reduce_sum(tf.reduce_mean(loss_matrix, axis=0))
  regularized_loss = loss + total_regularization_loss
  return regularized_loss


def mnist_input_fn(params):
  dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)
  # Shuffle before repeat is correct unless you want repeat cases in the
  # same batch.
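  # Two TPU-specific details in the pipeline below, noted here for clarity:
  # TPUEstimator passes the per-shard (per-core) batch size in
  # params['batch_size'], and drop_remainder=True is required so every batch
  # has a fully static shape, which XLA/TPU compilation needs, e.g.:
  #
  #   dataset.batch(params['batch_size'], drop_remainder=True)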
dataset = ( dataset.shuffle(num_examples).repeat().batch( params['batch_size'], drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)) return dataset def print_tensors(**tensors): """Host call function to print Tensors from the TPU during training.""" print_op = tf.no_op() for name in sorted(tensors): with tf.control_dependencies([print_op]): tensor = tensors[name] if name in ['error', 'loss']: tensor = tf.reduce_mean(tensor) print_op = tf.Print(tensor, [tensor], message=name + '=') with tf.control_dependencies([print_op]): return tf.Print(0., [0.], message='------') def _model_fn(features, labels, mode, params): """Estimator model_fn for an autoencoder with adaptive damping.""" del params layer_collection = kfac.LayerCollection() training_model_fn = autoencoder_mnist.AutoEncoder(784) def loss_fn(minibatch, logits=None): """Compute the model loss given a batch of inputs. Args: minibatch: `Tuple[Tensor, Tensor]` for the current batch of input images and labels. logits: `Tensor` for the current batch of logits. If None then reuses the AutoEncoder to compute them. Returns: `Tensor` for the batch loss. """ features, labels = minibatch del labels if logits is None: # Note we do not need to do anything like # `with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):` # here because Sonnet takes care of variable reuse for us as long as we # call the same `training_model_fn` module. Otherwise we would need to # use variable reusing here. logits = training_model_fn(features) batch_loss = compute_loss(logits=logits, labels=features) return batch_loss logits = training_model_fn(features) pre_update_batch_loss = loss_fn((features, labels), logits=logits) pre_update_batch_error = compute_squared_error(logits, features) if mode == tf.estimator.ModeKeys.TRAIN: # Make sure never to confuse this with register_softmax_cross_entropy_loss! layer_collection.register_sigmoid_cross_entropy_loss(logits, seed=FLAGS.seed + 1) layer_collection.auto_register_layers() global_step = tf.train.get_or_create_global_step() train_op, kfac_optimizer = make_train_op( (features, labels), pre_update_batch_loss, layer_collection, loss_fn) tensors_to_print = { 'learning_rate': tf.expand_dims(kfac_optimizer.learning_rate, 0), 'momentum': tf.expand_dims(kfac_optimizer.momentum, 0), 'damping': tf.expand_dims(kfac_optimizer.damping, 0), 'global_step': tf.expand_dims(global_step, 0), 'loss': tf.expand_dims(pre_update_batch_loss, 0), 'error': tf.expand_dims(pre_update_batch_error, 0), } if FLAGS.adapt_damping: tensors_to_print['qmodel_change'] = tf.expand_dims( kfac_optimizer.qmodel_change, 0) tensors_to_print['rho'] = tf.expand_dims(kfac_optimizer.rho, 0) return contrib_tpu.TPUEstimatorSpec( mode=mode, loss=pre_update_batch_loss, train_op=train_op, host_call=(print_tensors, tensors_to_print), eval_metrics=None) else: # mode == tf.estimator.ModeKeys.{EVAL, PREDICT}: return contrib_tpu.TPUEstimatorSpec( mode=mode, loss=pre_update_batch_loss, eval_metrics=None) def make_tpu_run_config(master, seed, model_dir, iterations_per_loop, save_checkpoints_steps): return contrib_tpu.RunConfig( master=master, evaluation_master=master, model_dir=model_dir, save_checkpoints_steps=save_checkpoints_steps, cluster=None, tf_random_seed=seed, tpu_config=contrib_tpu.TPUConfig(iterations_per_loop=iterations_per_loop)) def main(argv): if FLAGS.use_control_flow_v2: tf.enable_control_flow_v2() del argv # Unused. tf.set_random_seed(FLAGS.seed) # Invert using cholesky decomposition + triangular solve. 
This is the only # code path for matrix inversion supported on TPU right now. kfac.utils.set_global_constants(posdef_inv_method='cholesky') kfac.fisher_factors.set_global_constants( eigenvalue_decomposition_threshold=10000) config = make_tpu_run_config( FLAGS.master, FLAGS.seed, FLAGS.model_dir, FLAGS.iterations_per_loop, FLAGS.save_checkpoints_steps) estimator = contrib_tpu.TPUEstimator( use_tpu=True, model_fn=_model_fn, config=config, train_batch_size=FLAGS.batch_size, eval_batch_size=1024) estimator.train( input_fn=mnist_input_fn, max_steps=FLAGS.train_steps, hooks=[]) if __name__ == '__main__': tf.disable_v2_behavior() tf.app.run(main) ================================================ FILE: kfac/examples/autoencoder_mnist_tpu_strategy.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Implementation of Deep AutoEncoder from Martens & Grosse (2015). This script demonstrates training on TPUs with TPUStrategy using the KFAC optimizer, updating the damping parameter according to the Levenberg-Marquardt rule, and using the quadratic model method for adapting the learning rate and momentum parameters. See third_party/tensorflow_kfac/google/examples/ae_tpu_xm_launcher.py for an example Borg launch script. If you can't access this launch script, some important things to know about running K-FAC on TPUs (at least for this example) are that you must use high-precision matrix multiplications. iterations_per_loop is not relevant when using TPU Strategy, but you must set it to 1 when using TPU Estimator. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports from absl import flags import kfac import tensorflow.compat.v1 as tf from kfac.examples import autoencoder_mnist from kfac.examples import mnist # TODO(znado): figure out the bug with this and update_damping_immediately=True. # TODO(znado): Add checkpointing code to the training loop. flags.DEFINE_integer('save_checkpoints_steps', 500, 'Number of iterations between model checkpoints.') flags.DEFINE_string('model_dir', '', 'Model dir.') # iterations_per_loop is not used with TPU Strategy. We keep the flag so the # Estimator launching script can be used. flags.DEFINE_integer('iterations_per_loop', 1, 'Number of iterations in a TPU training loop.') flags.DEFINE_string('master', None, 'GRPC URL of the master ' '(e.g. grpc://ip.address.of.tpu:8470).') FLAGS = flags.FLAGS def make_train_op(minibatch, batch_loss, layer_collection, loss_fn): """Constructs optimizer and train op. Args: minibatch: Tuple[Tensor, Tensor] representing the current batch of input images and labels. batch_loss: Tensor of shape (), Loss with respect to minibatch to be minimzed. layer_collection: LayerCollection object. Registry for model parameters. Required when using a K-FAC optimizer. 
    loss_fn: A function that when called constructs the graph to compute the
      model loss on the current minibatch. Returns a Tensor of the loss scalar.

  Returns:
    train_op: Op that can be used to update model parameters.
    optimizer: The KFAC optimizer used to produce train_op.

  Raises:
    ValueError: If layer_collection is None when K-FAC is selected as an
      optimization method.
  """
  # Do not use CrossShardOptimizer with K-FAC. K-FAC now handles its own
  # cross-replica synchronization automatically!
  return autoencoder_mnist.make_train_op(
      minibatch=minibatch,
      batch_size=minibatch[0].get_shape().as_list()[0],
      batch_loss=batch_loss,
      layer_collection=layer_collection,
      loss_fn=loss_fn,
      prev_train_batch=None,
      placement_strategy='replica_round_robin',
  )


def compute_squared_error(logits, targets):
  """Compute mean squared error."""
  return tf.reduce_sum(
      tf.reduce_mean(tf.square(targets - tf.nn.sigmoid(logits)), axis=0))


def compute_loss(logits, labels, model):
  """Compute loss value."""
  loss_matrix = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                        labels=labels)
  regularization_loss = tf.reduce_sum(model.losses)
  crossentropy_loss = tf.reduce_sum(tf.reduce_mean(loss_matrix, axis=0))
  return crossentropy_loss + regularization_loss


def mnist_input_fn(batch_size):
  dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)
  # Shuffle before repeat is correct unless you want repeat cases in the
  # same batch.
  dataset = (dataset.shuffle(num_examples)
             .repeat()
             .batch(batch_size, drop_remainder=True)
             .prefetch(tf.data.experimental.AUTOTUNE))
  return dataset


def _train_step(batch):
  """Builds one training step for the autoencoder, with adaptive damping."""
  features, labels = batch
  model = autoencoder_mnist.get_keras_autoencoder(tensor=features)

  def loss_fn(minibatch, logits=None):
    """Compute the model loss given a batch of inputs.

    Args:
      minibatch: `Tuple[Tensor, Tensor]` for the current batch of input images
        and labels.
      logits: `Tensor` for the current batch of logits. If None then reuses the
        AutoEncoder to compute them.

    Returns:
      `Tensor` for the batch loss.
    """
    features, labels = minibatch
    del labels
    if logits is None:
      logits = model(features)
    batch_loss = compute_loss(logits=logits, labels=features, model=model)
    return batch_loss

  logits = model.output
  pre_update_batch_loss = loss_fn((features, labels), logits)
  pre_update_batch_error = compute_squared_error(logits, features)
  # binary_crossentropy corresponds to sigmoid_crossentropy.
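  # kfac.keras.utils.get_layer_collection (used just below) builds the
  # LayerCollection from the Keras model automatically, registering each
  # parameterized layer with its inputs/outputs plus a predictive distribution
  # matching the named Keras loss. A hand-rolled equivalent (illustrative
  # sketch only; `layer_in`/`layer_out` are placeholders) would look roughly
  # like the Sonnet-based registration in autoencoder_mnist.py:
  #
  #   lc = kfac.LayerCollection()
  #   lc.register_fully_connected((layer.kernel, layer.bias),
  #                               layer_in, layer_out)
  #   lc.register_sigmoid_cross_entropy_loss(logits)  # 'binary_crossentropy'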
layer_collection = kfac.keras.utils.get_layer_collection( model, 'binary_crossentropy', seed=FLAGS.seed + 1) global_step = tf.train.get_or_create_global_step() train_op, kfac_optimizer = make_train_op( (features, labels), pre_update_batch_loss, layer_collection, loss_fn) tensors_to_print = { 'learning_rate': kfac_optimizer.learning_rate, 'momentum': kfac_optimizer.momentum, 'damping': kfac_optimizer.damping, 'global_step': global_step, 'loss': pre_update_batch_loss, 'error': pre_update_batch_error, } if FLAGS.adapt_damping: tensors_to_print['qmodel_change'] = kfac_optimizer.qmodel_change tensors_to_print['rho'] = kfac_optimizer.rho with tf.control_dependencies([train_op]): return {k: tf.identity(v) for k, v in tensors_to_print.items()} def train(): """Trains the Autoencoder using TPU Strategy.""" cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( tpu=FLAGS.master) tf.tpu.experimental.initialize_tpu_system(cluster_resolver) tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver) with tpu_strategy.scope(): data = mnist_input_fn(batch_size=FLAGS.batch_size) train_iterator = tpu_strategy.make_dataset_iterator(data) tensor_dict = tpu_strategy.experimental_run(_train_step, train_iterator) for k, v in tensor_dict.items(): if k in ('loss', 'error'): # Losses are NOT scaled for num replicas. tensor_dict[k] = tpu_strategy.reduce(tf.distribute.ReduceOp.MEAN, v) else: # Other tensors (hyperparameters) are identical across replicas. # experimental_local_results gives you a tuple of per-replica values. tensor_dict[k] = tpu_strategy.experimental_local_results(v) config = tf.ConfigProto(allow_soft_placement=True) with tf.Session(cluster_resolver.master(), config=config) as session: session.run(tf.global_variables_initializer()) session.run(train_iterator.initializer) print('Starting training.') for step in range(FLAGS.train_steps): values_dict = session.run(tensor_dict) print('Training Step: {}'.format(step)) for k, v in values_dict.items(): print('{}: {}'.format(k, v)) print('Done training.') def main(argv): del argv # Unused. tf.set_random_seed(FLAGS.seed) # Invert using cholesky decomposition + triangular solve. This is the only # code path for matrix inversion supported on TPU right now. kfac.utils.set_global_constants(posdef_inv_method='cholesky') kfac.fisher_factors.set_global_constants( eigenvalue_decomposition_threshold=10000) train() if __name__ == '__main__': tf.disable_v2_behavior() tf.app.run(main) ================================================ FILE: kfac/examples/classifier_mnist.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """A simple MNIST classifier example. 
This script demonstrates training using KFAC optimizer, updating the damping parameter according to the Levenberg-Marquardt rule, and using the quadratic model method for adapting the learning rate and momentum parameters. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import math # Dependency imports from absl import flags import kfac import sonnet as snt import tensorflow.compat.v1 as tf from kfac.examples import mnist from kfac.python.ops.kfac_utils import data_reader from kfac.python.ops.kfac_utils import data_reader_alt # Model parameters _NONLINEARITY = tf.nn.relu # can also be tf.nn.tanh _POOL = 'MAX' # can also be 'AVG' flags.DEFINE_integer('train_steps', 10000, 'Number of training steps.') flags.DEFINE_integer('inverse_update_period', 5, '# of steps between computing inverse of Fisher factor ' 'matrices.') flags.DEFINE_integer('cov_update_period', 1, '# of steps between computing covaraiance matrices.') flags.DEFINE_integer('damping_adaptation_interval', 5, '# of steps between updating the damping parameter.') flags.DEFINE_integer('num_burnin_steps', 5, 'Number of steps the at the ' 'start of training where the optimizer will only perform ' 'cov updates. Will not work on CrossShardOptimizer. See ' 'PeriodicInvCovUpdateKfacOpt for details.') flags.DEFINE_integer('seed', 12345, 'Random seed') flags.DEFINE_float('learning_rate', 3e-3, 'Learning rate to use when lrmu_adaptation="off".') flags.DEFINE_float('momentum', 0.9, 'Momentum decay value to use when ' 'lrmu_adaptation="off" or "only_lr".') flags.DEFINE_float('damping', 1e-2, 'The fixed damping value to use. This is ' 'ignored if adapt_damping is True.') flags.DEFINE_float('l2_reg', 1e-5, 'L2 regularization applied to weight matrices.') flags.DEFINE_boolean('update_damping_immediately', True, 'Adapt the damping ' 'immediately after the parameter update (i.e. in the same ' 'sess.run() call). Only safe if everything is a resource ' 'variable.') flags.DEFINE_boolean('use_batch_size_schedule', True, 'If True then we use the growing mini-batch schedule from ' 'the original K-FAC paper.') flags.DEFINE_integer('batch_size', 1024, 'The size of the mini-batches to use if not using the ' 'schedule.') flags.DEFINE_string('lrmu_adaptation', 'on', 'If set to "on" then we use the quadratic model ' 'based learning-rate and momentum adaptation method from ' 'the original paper. Note that this only works well in ' 'practice when use_batch_size_schedule=True. Can also ' 'be set to "off" and "only_lr", which turns ' 'it off, or uses a version where the momentum parameter ' 'is fixed (resp.).') flags.DEFINE_boolean('use_alt_data_reader', True, 'If True we use the alternative data reader for MNIST ' 'that is faster for small datasets.') flags.DEFINE_string('device', '/gpu:0', 'The device to run the major ops on.') flags.DEFINE_boolean('adapt_damping', True, 'If True we use the LM rule for damping adaptation as ' 'described in the original K-FAC paper.') # When using damping adaptation it is advisable to start with a high # value. This value is probably far too high to use for most neural nets # if you aren't using damping adaptation. (Although it always depends on # the scale of the loss.) flags.DEFINE_float('initial_damping', 0.1, 'The initial damping value to use when adapt_damping is ' 'True.') flags.DEFINE_string('optimizer', 'kfac', 'The optimizer to use. Can be kfac or adam. 
If adam is ' 'used the various kfac hyperparameter map roughly on to ' 'their Adam equivalents.') flags.DEFINE_float('polyak_decay', 0.995, 'Rate of decay for Polyak averaging.') flags.DEFINE_integer('eval_every', 50, 'Interval to print total training loss.') flags.DEFINE_boolean('use_sua_approx', False, 'If True we use the SUA approximation for conv layers.') flags.DEFINE_string('dtype', 'float32', 'The DTYPE to use for all computations. Can by float32 ' 'or float64.') flags.DEFINE_boolean('use_custom_patches_op', False, 'If True we use the custom XLA implementation of the op ' 'which computes the second moment of patch vectors.') FLAGS = flags.FLAGS class Model(snt.AbstractModule): """CNN model for MNIST data.""" def _build(self, inputs): if FLAGS.l2_reg: regularizers = {'w': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w), 'b': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),} else: regularizers = None reshape = snt.BatchReshape([28, 28, 1]) conv = snt.Conv2D(2, 5, padding=snt.SAME, regularizers=regularizers) act = _NONLINEARITY(conv(reshape(inputs))) pool = tf.nn.pool(act, window_shape=(2, 2), pooling_type=_POOL, padding=snt.SAME, strides=(2, 2)) conv = snt.Conv2D(4, 5, padding=snt.SAME, regularizers=regularizers) act = _NONLINEARITY(conv(pool)) pool = tf.nn.pool(act, window_shape=(2, 2), pooling_type=_POOL, padding=snt.SAME, strides=(2, 2)) flatten = snt.BatchFlatten()(pool) linear = snt.Linear(32, regularizers=regularizers)(flatten) return snt.Linear(10, regularizers=regularizers)(linear) def make_train_op(minibatch, batch_size, batch_loss, layer_collection, loss_fn, prev_train_batch=None, placement_strategy=None, print_logs=False, tf_replicator=None): """Constructs optimizer and train op. Args: minibatch: A list/tuple of Tensors (typically representing the current mini-batch of input images and labels). batch_size: Tensor of shape (). Size of the training mini-batch. batch_loss: Tensor of shape (). Mini-batch loss tensor. layer_collection: LayerCollection object. Registry for model parameters. Required when using a K-FAC optimizer. loss_fn: Function which takes as input a mini-batch and returns the loss. prev_train_batch: `Tensor` of the previous training batch, can be accessed from the data_reader.CachedReader cached_batch property. (Default: None) placement_strategy: `str`, the placement_strategy argument for `KfacOptimizer`. (Default: None) print_logs: `Bool`. If True we print logs using K-FAC's built-in tf.print-based logs printer. (Default: False) tf_replicator: A Replicator object or None. If not None, K-FAC will set itself up to work inside of the provided TF-Replicator object. (Default: None) Returns: train_op: Op that can be used to update model parameters. optimizer: Optimizer used to produce train_op. Raises: ValueError: If layer_collection is None when K-FAC is selected as an optimization method. 
""" global_step = tf.train.get_or_create_global_step() if FLAGS.optimizer == 'kfac': if FLAGS.lrmu_adaptation == 'on': learning_rate = None momentum = None momentum_type = 'qmodel' elif FLAGS.lrmu_adaptation == 'only_lr': learning_rate = None momentum = FLAGS.momentum momentum_type = 'qmodel_fixedmu' elif FLAGS.lrmu_adaptation == 'off': learning_rate = FLAGS.learning_rate momentum = FLAGS.momentum # momentum_type = 'regular' momentum_type = 'adam' if FLAGS.adapt_damping: damping = FLAGS.initial_damping else: damping = FLAGS.damping optimizer = kfac.PeriodicInvCovUpdateKfacOpt( invert_every=FLAGS.inverse_update_period, cov_update_every=FLAGS.cov_update_period, learning_rate=learning_rate, damping=damping, cov_ema_decay=0.95, momentum=momentum, momentum_type=momentum_type, layer_collection=layer_collection, batch_size=batch_size, num_burnin_steps=FLAGS.num_burnin_steps, adapt_damping=FLAGS.adapt_damping, # Note that many of the arguments below don't do anything when # adapt_damping=False. update_damping_immediately=FLAGS.update_damping_immediately, is_chief=True, prev_train_batch=prev_train_batch, loss=batch_loss, loss_fn=loss_fn, damping_adaptation_decay=0.9, damping_adaptation_interval=FLAGS.damping_adaptation_interval, min_damping=1e-6, l2_reg=FLAGS.l2_reg, train_batch=minibatch, placement_strategy=placement_strategy, print_logs=print_logs, tf_replicator=tf_replicator, dtype=FLAGS.dtype, ) elif FLAGS.optimizer == 'adam': optimizer = tf.train.AdamOptimizer( learning_rate=FLAGS.learning_rate, beta1=FLAGS.momentum, epsilon=FLAGS.damping, beta2=0.99) return optimizer.minimize(batch_loss, global_step=global_step), optimizer def compute_loss(logits=None, labels=None, return_error=False, use_regularizer=True): """Compute loss value.""" loss_matrix = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) total_loss = tf.reduce_mean(loss_matrix, axis=0) if use_regularizer: graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) total_regularization_loss = tf.add_n(graph_regularizers) total_loss += tf.cast(total_regularization_loss, dtype=total_loss.dtype) if return_error: error = 1.0 - tf.reduce_mean(tf.cast( tf.equal(labels, tf.argmax(logits, axis=1, output_type=tf.int32)), tf.float32)) return total_loss, error return total_loss def load_mnist(): """Creates MNIST dataset and wraps it inside cached data reader. Returns: cached_reader: `data_reader.CachedReader` instance which wraps MNIST dataset. num_examples: int. The number of training examples. """ # Wrap the data set into cached_reader which provides variable sized training # and caches the read train batch. if not FLAGS.use_alt_data_reader: # Version 1 using data_reader.py (slow!) dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True) if FLAGS.use_batch_size_schedule: max_batch_size = num_examples else: max_batch_size = FLAGS.batch_size # Shuffle before repeat is correct unless you want repeat cases in the # same batch. 
dataset = (dataset.shuffle(num_examples).repeat() .batch(max_batch_size).prefetch(5)) dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() # This version of CachedDataReader requires the dataset to be shuffled return data_reader.CachedDataReader(dataset, max_batch_size), num_examples else: # Version 2 using data_reader_alt.py (faster) images, labels, num_examples = mnist.load_mnist_as_tensors( flatten_images=True, dtype=tf.dtypes.as_dtype(FLAGS.dtype)) dataset = (images, labels) # This version of CachedDataReader requires the dataset to NOT be shuffled return data_reader_alt.CachedDataReader(dataset, num_examples), num_examples def _get_batch_size_schedule(num_examples): """Returns training batch size schedule.""" minibatch_maxsize_targetiter = 100 # We use a smaller target iter here than # in the autoencoder example. minibatch_maxsize = num_examples minibatch_startsize = 1000 div = (float(minibatch_maxsize_targetiter-1) / math.log(float(minibatch_maxsize)/minibatch_startsize, 2)) return [ min(int(2.**(float(k)/div) * minibatch_startsize), minibatch_maxsize) for k in range(minibatch_maxsize_targetiter) ] def group_assign(dest, source): return tf.group(*(d.assign(s) for d, s in zip(dest, source))) def make_eval_ops(train_vars, ema): # This does evaluation with and without Polyak averaging. images, labels, _ = mnist.load_mnist_as_tensors( flatten_images=True, dtype=tf.dtypes.as_dtype(FLAGS.dtype)) eval_model = Model() eval_model(images) # We need this dummy call because the variables won't # exist otherwise. eval_vars = eval_model.variables update_eval_model = group_assign(eval_vars, train_vars) with tf.control_dependencies([update_eval_model]): logits = eval_model(images) eval_loss, eval_error = compute_loss( logits=logits, labels=labels, return_error=True) with tf.control_dependencies([eval_loss, eval_error]): update_eval_model_avg = group_assign( eval_vars, (ema.average(t) for t in train_vars)) with tf.control_dependencies([update_eval_model_avg]): logits = eval_model(images) eval_loss_avg, eval_error_avg = compute_loss( logits=logits, labels=labels, return_error=True) return eval_loss, eval_error, eval_loss_avg, eval_error_avg def construct_train_quants(): with tf.device(FLAGS.device): # Load dataset. cached_reader, num_examples = load_mnist() batch_size_schedule = _get_batch_size_schedule(num_examples) batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size') minibatch = cached_reader(batch_size) features, _ = minibatch training_model = Model() layer_collection = kfac.LayerCollection() if FLAGS.use_sua_approx: layer_collection.set_default_conv2d_approximation('kron_sua') ema = tf.train.ExponentialMovingAverage(FLAGS.polyak_decay, zero_debias=True) def loss_fn(minibatch, logits=None, return_error=False): features, labels = minibatch if logits is None: logits = training_model(features) return compute_loss( logits=logits, labels=labels, return_error=return_error) logits = training_model(features) (batch_loss, batch_error) = loss_fn( minibatch, logits=logits, return_error=True) # Make sure never to confuse this with register_sigmoid_cross_entropy_loss! 
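    # (The registered loss tells K-FAC which predictive distribution the
    # logits parameterize, and hence how to estimate the Fisher for the output
    # layer: register_softmax_cross_entropy_loss declares a single categorical
    # distribution over the classes, as appropriate here, while
    # register_sigmoid_cross_entropy_loss, used in the autoencoder example,
    # declares an independent Bernoulli per output dimension.)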
layer_collection.register_softmax_cross_entropy_loss(logits, seed=FLAGS.seed + 1) layer_collection.auto_register_layers() train_vars = training_model.variables # Make training op: train_op, opt = make_train_op( minibatch, batch_size, batch_loss, layer_collection, loss_fn=loss_fn, prev_train_batch=cached_reader.cached_batch) with tf.control_dependencies([train_op]): train_op = ema.apply(train_vars) # We clear out the regularizers collection when creating our evaluation # graph (which uses different variables). It is important that we do this # only after the train op is constructed, since the minimize() will call # into the loss function (which includes the regularizer): tf.get_default_graph().clear_collection(tf.GraphKeys.REGULARIZATION_LOSSES) # These aren't run in the same sess.run call as train_op: (eval_loss, eval_error, eval_loss_avg, eval_error_avg) = make_eval_ops(train_vars, ema) return (train_op, opt, batch_loss, batch_error, batch_size_schedule, batch_size, eval_loss, eval_error, eval_loss_avg, eval_error_avg) def main(_): # If using update_damping_immediately resource variables must be enabled. if FLAGS.update_damping_immediately: tf.enable_resource_variables() if not FLAGS.use_sua_approx: if FLAGS.use_custom_patches_op: kfac.fisher_factors.set_global_constants( use_patches_second_moment_op=True ) else: # Temporary measure to save memory with giant batches: kfac.fisher_factors.set_global_constants( sub_sample_inputs=True, inputs_to_extract_patches_factor=0.2) tf.set_random_seed(FLAGS.seed) (train_op, opt, batch_loss, batch_error, batch_size_schedule, batch_size, eval_loss, eval_error, eval_loss_avg, eval_error_avg) = construct_train_quants() global_step = tf.train.get_or_create_global_step() if FLAGS.optimizer == 'kfac': # We need to put the control depenency on train_op here so that we are # guaranteed to get the up-to-date values of these various quantities. # Otherwise there is a race condition and we might get the old values, # nondeterministically. Another solution would be to get these values in # a separate sess.run call, but this can sometimes cause problems with # training frameworks that use hooks (see the comments below). with tf.control_dependencies([train_op]): learning_rate = opt.learning_rate momentum = opt.momentum damping = opt.damping rho = opt.rho qmodel_change = opt.qmodel_change # Without setting allow_soft_placement=True there will be problems when # the optimizer tries to place certain ops like "mod" on the GPU (which isn't # supported). config = tf.ConfigProto(allow_soft_placement=True) # Train model. # It's good practice to put everything into a single sess.run call. The # reason is that certain "training frameworks" like to run hooks at each # sess.run call, and there is an implicit expectation there will only # be one sess.run call every "iteration" of the "optimizer". For example, # a framework might try to print the loss at each sess.run call, causing # the mini-batch to be advanced, thus completely breaking the "cached # batch" mechanism that the damping adaptation method may rely on. (Plus # there will also be the extra cost of having to reevaluate the loss # twice.) That being said we don't completely do that here because it's # inconvenient. 
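  # Concretely, the part that is not bundled below is the separate
  # sess.run(global_step) used to look up the batch-size schedule. A fully
  # bundled version (illustrative sketch only, reusing the step value fetched
  # on the previous iteration to pick `bs`) would look roughly like:
  #
  #   fetches = [train_op, batch_loss, batch_error, global_step]
  #   _, loss_, error_, step_ = sess.run(fetches, feed_dict={batch_size: bs})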
with tf.train.MonitoredTrainingSession(save_checkpoint_secs=30, config=config) as sess: for _ in range(FLAGS.train_steps): i = sess.run(global_step) if FLAGS.use_batch_size_schedule: batch_size_ = batch_size_schedule[min(i, len(batch_size_schedule) - 1)] else: batch_size_ = FLAGS.batch_size if FLAGS.optimizer == 'kfac': (_, batch_loss_, batch_error_, learning_rate_, momentum_, damping_, rho_, qmodel_change_) = sess.run([train_op, batch_loss, batch_error, learning_rate, momentum, damping, rho, qmodel_change], feed_dict={batch_size: batch_size_}) else: _, batch_loss_, batch_error_ = sess.run( [train_op, batch_loss, batch_error], feed_dict={batch_size: batch_size_}) # Print training stats. tf.logging.info( 'iteration: %d', i) tf.logging.info( 'mini-batch size: %d | mini-batch loss = %f | mini-batch error = %f ', batch_size_, batch_loss_, batch_error_) if FLAGS.optimizer == 'kfac': tf.logging.info( 'learning_rate = %f | momentum = %f', learning_rate_, momentum_) tf.logging.info( 'damping = %f | rho = %f | qmodel_change = %f', damping_, rho_, qmodel_change_) # "Eval" here means just compute stuff on the full training set. if (i+1) % FLAGS.eval_every == 0: eval_loss_, eval_error_, eval_loss_avg_, eval_error_avg_ = sess.run( [eval_loss, eval_error, eval_loss_avg, eval_error_avg]) tf.logging.info('-----------------------------------------------------') tf.logging.info('eval_loss = %f | eval_error = %f', eval_loss_, eval_error_) tf.logging.info('eval_loss_avg = %f | eval_error_avg = %f', eval_loss_avg_, eval_error_avg_) tf.logging.info('-----------------------------------------------------') else: tf.logging.info('----') if __name__ == '__main__': tf.disable_v2_behavior() tf.app.run(main) ================================================ FILE: kfac/examples/classifier_mnist_tpu_estimator.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """A simple MNIST classifier example. This script demonstrates training on TPUs with TPU Estimator using the KFAC optimizer, updating the damping parameter according to the Levenberg-Marquardt rule, and using the quadratic model method for adapting the learning rate and momentum parameters. See third_party/tensorflow_kfac/google/examples/classifier_tpu_xm_launcher.py for an example Borg launch script. If you can't access this launch script, some important things to know about running K-FAC on TPUs (at least for this example) are that you must use higher-precision matrix multiplications. 
""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports from absl import flags import kfac import tensorflow.compat.v1 as tf from tensorflow.contrib import tpu as contrib_tpu from kfac.examples import classifier_mnist from kfac.examples import mnist flags.DEFINE_integer('save_checkpoints_steps', 500, 'Number of iterations between model checkpoints.') flags.DEFINE_integer('iterations_per_loop', 100, 'Number of iterations in a TPU training loop.') flags.DEFINE_string('model_dir', '', 'Model dir.') flags.DEFINE_string('master', None, 'GRPC URL of the master ' '(e.g. grpc://ip.address.of.tpu:8470).') FLAGS = flags.FLAGS def make_train_op(minibatch, batch_loss, layer_collection, loss_fn): """Constructs optimizer and train op. Args: minibatch: Tuple[Tensor, Tensor] representing the current batch of input images and labels. batch_loss: Tensor of shape (), Loss with respect to minibatch to be minimzed. layer_collection: LayerCollection object. Registry for model parameters. Required when using a K-FAC optimizer. loss_fn: A function that when called constructs the graph to compute the model loss on the current minibatch. Returns a Tensor of the loss scalar. Returns: train_op: Op that can be used to update model parameters. optimizer: The KFAC optimizer used to produce train_op. Raises: ValueError: If layer_collection is None when K-FAC is selected as an optimization method. """ # Do not use CrossShardOptimizer with K-FAC. K-FAC now handles its own # cross-replica syncronization automatically! return classifier_mnist.make_train_op( minibatch=minibatch, batch_size=minibatch[0].get_shape().as_list()[0], batch_loss=batch_loss, layer_collection=layer_collection, loss_fn=loss_fn, prev_train_batch=None, placement_strategy='replica_round_robin', ) def mnist_input_fn(params): dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True) # Shuffle before repeat is correct unless you want repeat cases in the # same batch. dataset = ( dataset.shuffle(num_examples).repeat().batch( params['batch_size'], drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)) return dataset def print_tensors(**tensors): """Host call function to print Tensors from the TPU during training.""" print_op = tf.no_op() for name in sorted(tensors): with tf.control_dependencies([print_op]): tensor = tensors[name] if name in ['error', 'loss']: tensor = tf.reduce_mean(tensor) print_op = tf.Print(tensor, [tensor], message=name + '=') with tf.control_dependencies([print_op]): return tf.Print(0., [0.], message='------') def _model_fn(features, labels, mode, params): """Estimator model_fn for an autoencoder with adaptive damping.""" del params training_model = classifier_mnist.Model() layer_collection = kfac.LayerCollection() def loss_fn(minibatch, logits=None, return_error=False): features, labels = minibatch if logits is None: # Note we do not need to do anything like # `with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):` # here because Sonnet takes care of variable reuse for us as long as we # call the same `training_model` module. Otherwise we would need to # use variable reusing here. 
logits = training_model(features) return classifier_mnist.compute_loss(logits=logits, labels=labels, return_error=return_error) logits = training_model(features) pre_update_batch_loss, pre_update_batch_error = loss_fn( (features, labels), logits=logits, return_error=True) global_step = tf.train.get_or_create_global_step() if mode == tf.estimator.ModeKeys.TRAIN: layer_collection.register_softmax_cross_entropy_loss(logits, seed=FLAGS.seed + 1) layer_collection.auto_register_layers() train_op, kfac_optimizer = make_train_op( (features, labels), pre_update_batch_loss, layer_collection, loss_fn) tensors_to_print = { 'learning_rate': tf.expand_dims(kfac_optimizer.learning_rate, 0), 'momentum': tf.expand_dims(kfac_optimizer.momentum, 0), 'damping': tf.expand_dims(kfac_optimizer.damping, 0), 'global_step': tf.expand_dims(global_step, 0), 'loss': tf.expand_dims(pre_update_batch_loss, 0), 'error': tf.expand_dims(pre_update_batch_error, 0), } if FLAGS.adapt_damping: tensors_to_print['qmodel_change'] = tf.expand_dims( kfac_optimizer.qmodel_change, 0) tensors_to_print['rho'] = tf.expand_dims(kfac_optimizer.rho, 0) return contrib_tpu.TPUEstimatorSpec( mode=mode, loss=pre_update_batch_loss, train_op=train_op, host_call=(print_tensors, tensors_to_print), eval_metrics=None) else: # mode == tf.estimator.ModeKeys.{EVAL, PREDICT}: return contrib_tpu.TPUEstimatorSpec( mode=mode, loss=pre_update_batch_loss, eval_metrics=None) def make_tpu_run_config(master, seed, model_dir, iterations_per_loop, save_checkpoints_steps): return contrib_tpu.RunConfig( master=master, evaluation_master=master, model_dir=model_dir, save_checkpoints_steps=save_checkpoints_steps, cluster=None, tf_random_seed=seed, tpu_config=contrib_tpu.TPUConfig(iterations_per_loop=iterations_per_loop)) def main(argv): del argv # Unused. # If using update_damping_immediately resource variables must be enabled. # (Although they probably will be by default on TPUs.) if FLAGS.update_damping_immediately: tf.enable_resource_variables() tf.set_random_seed(FLAGS.seed) # Invert using cholesky decomposition + triangular solve. This is the only # code path for matrix inversion supported on TPU right now. kfac.utils.set_global_constants(posdef_inv_method='cholesky') kfac.fisher_factors.set_global_constants( eigenvalue_decomposition_threshold=10000) if not FLAGS.use_sua_approx: if FLAGS.use_custom_patches_op: kfac.fisher_factors.set_global_constants( use_patches_second_moment_op=True ) else: # Temporary measure to save memory with giant batches: kfac.fisher_factors.set_global_constants( sub_sample_inputs=True, inputs_to_extract_patches_factor=0.1) config = make_tpu_run_config( FLAGS.master, FLAGS.seed, FLAGS.model_dir, FLAGS.iterations_per_loop, FLAGS.save_checkpoints_steps) estimator = contrib_tpu.TPUEstimator( use_tpu=True, model_fn=_model_fn, config=config, train_batch_size=FLAGS.batch_size, eval_batch_size=1024) estimator.train( input_fn=mnist_input_fn, max_steps=FLAGS.train_steps, hooks=[]) if __name__ == '__main__': tf.disable_v2_behavior() tf.app.run(main) ================================================ FILE: kfac/examples/convnet.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train a ConvNet on MNIST using K-FAC.

This library demonstrates how to use K-FAC to train a 5-layer ConvNet on
MNIST. Note that this example is basically untuned and is not meant to work
as an actual demonstration of the power of the method. It may not even
converge. It merely demonstrates how to set up K-FAC to run under the various
standard modes of operation in Tensorflow, like SyncReplicas, Estimator, etc.

For an example of the method tuned properly and working well, see for example
the autoencoder_mnist.py example, which replicates the exact experiment from
the original K-FAC paper.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import kfac
import numpy as np
import tensorflow.compat.v1 as tf

from kfac.examples import mnist

__all__ = [
    "conv_layer",
    "fc_layer",
    "max_pool_layer",
    "build_model",
    "minimize_loss_single_machine",
    "distributed_grads_only_and_ops_chief_worker",
    "distributed_grads_and_ops_dedicated_workers",
    "train_mnist_single_machine",
    "train_mnist_distributed_sync_replicas",
    "train_mnist_multitower"
]


# Inverse update ops will be run every _INVERT_EVERY iterations.
_INVERT_EVERY = 10

# Covariance matrices will be updated every _COV_UPDATE_EVERY iterations.
_COV_UPDATE_EVERY = 1

# Displays loss every _REPORT_EVERY iterations.
_REPORT_EVERY = 10

# Use manual registration
_USE_MANUAL_REG = False


def fc_layer(layer_id, inputs, output_size):
  """Builds a fully connected layer.

  Args:
    layer_id: int. Integer ID for this layer's variables.
    inputs: Tensor of shape [num_examples, input_size]. Each row corresponds
      to a single example.
    output_size: int. Number of output dimensions after fully connected layer.

  Returns:
    preactivations: Tensor of shape [num_examples, output_size]. Values of the
      layer immediately before the activation function.
    activations: Tensor of shape [num_examples, output_size]. Values of the
      layer immediately after the activation function.
    params: Tuple of (weights, bias), parameters for this layer.
  """
  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
  layer = tf.layers.Dense(
      output_size,
      kernel_initializer=tf.random_normal_initializer(),
      name="fc_%d" % layer_id)
  preactivations = layer(inputs)
  activations = tf.nn.tanh(preactivations)

  # layer.weights is a list. This converts it to a (hashable) tuple.
  return preactivations, activations, (layer.kernel, layer.bias)


def conv_layer(layer_id, inputs, kernel_size, out_channels):
  """Builds a convolutional layer with ReLU non-linearity.

  Args:
    layer_id: int. Integer ID for this layer's variables.
    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each
      row corresponds to a single example.
    kernel_size: int. Width and height of the convolution kernel. The kernel
      is assumed to be square.
    out_channels: int. Number of output features per pixel.

  Returns:
    preactivations: Tensor of shape [num_examples, width, height,
      out_channels]. Values of the layer immediately before the activation
      function.
activations: Tensor of shape [num_examples, width, height, out_channels]. Values of the layer immediately after the activation function. params: Tuple of (kernel, bias), parameters for this layer. """ # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. layer = tf.layers.Conv2D( out_channels, kernel_size=[kernel_size, kernel_size], kernel_initializer=tf.random_normal_initializer(stddev=0.01), padding="SAME", name="conv_%d" % layer_id) preactivations = layer(inputs) activations = tf.nn.relu(preactivations) # layer.weights is a list. This converts it a (hashable) tuple. return preactivations, activations, (layer.kernel, layer.bias) def max_pool_layer(layer_id, inputs, kernel_size, stride): """Build a max-pooling layer. Args: layer_id: int. Integer ID for this layer's variables. inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row corresponds to a single example. kernel_size: int. Width and height to pool over per input channel. The kernel is assumed to be square. stride: int. Step size between pooling operations. Returns: Tensor of shape [num_examples, width/stride, height/stride, out_channels]. Result of applying max pooling to 'inputs'. """ # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. with tf.variable_scope("pool_%d" % layer_id): return tf.nn.max_pool( inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1], padding="SAME", name="pool") def build_model(examples, labels, num_labels, layer_collection, register_layers_manually=False): """Builds a ConvNet classification model. Args: examples: Tensor of shape [num_examples, num_features]. Represents inputs of model. labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted by softmax for each example. num_labels: int. Number of distinct values 'labels' can take on. layer_collection: LayerCollection instance. Layers will be registered here. register_layers_manually: bool. If True then register the layers with layer_collection manually. (Default: False) Returns: loss: 0-D Tensor representing loss to be minimized. accuracy: 0-D Tensor representing model's accuracy. """ # Build a ConvNet. For each layer with parameters, we'll keep track of the # preactivations, activations, weights, and bias. 
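  # Schematically, the model built below is:
  #   conv 5x5, 16 ch -> relu -> max-pool 3x3 stride 2 ->
  #   conv 5x5, 16 ch -> relu -> max-pool 3x3 stride 2 ->
  #   flatten -> fully connected -> softmax cross-entropy loss.
  # Only the parameterized layers (the two convolutions and the final fully
  # connected layer) need to be registered with the LayerCollection, either
  # manually below when register_layers_manually=True or, presumably, via
  # K-FAC's automatic graph-scanning registration handled outside this
  # function.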
tf.logging.info("Building model.") pre0, act0, params0 = conv_layer( layer_id=0, inputs=examples, kernel_size=5, out_channels=16) act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2) pre2, act2, params2 = conv_layer( layer_id=2, inputs=act1, kernel_size=5, out_channels=16) act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2) flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))]) logits, _, params4 = fc_layer( layer_id=4, inputs=flat_act3, output_size=num_labels) loss = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=logits)) accuracy = tf.reduce_mean( tf.cast(tf.equal(tf.cast(labels, dtype=tf.int32), tf.argmax(logits, axis=1, output_type=tf.int32)), dtype=tf.float32)) with tf.device("/cpu:0"): tf.summary.scalar("loss", loss) tf.summary.scalar("accuracy", accuracy) layer_collection.register_softmax_cross_entropy_loss( logits, name="logits") if register_layers_manually: layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples, pre0) layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2) layer_collection.register_fully_connected(params4, flat_act3, logits) return loss, accuracy def minimize_loss_single_machine(loss, accuracy, layer_collection, device=None, session_config=None): """Minimize loss with K-FAC on a single machine. Creates `PeriodicInvCovUpdateKfacOpt` which handles inverse and covariance computation op placement and execution. A single Session is responsible for running all of K-FAC's ops. The covariance and inverse update ops are placed on `device`. All model variables are on CPU. Args: loss: 0-D Tensor. Loss to be minimized. accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. device: string or None. The covariance and inverse update ops are run on this device. If empty or None, the default device will be used. (Default: None) session_config: None or tf.ConfigProto. Configuration for tf.Session(). Returns: final value for 'accuracy'. """ device_list = [] if not device else [device] # Train with K-FAC. g_step = tf.train.get_or_create_global_step() optimizer = kfac.PeriodicInvCovUpdateKfacOpt( invert_every=_INVERT_EVERY, cov_update_every=_COV_UPDATE_EVERY, learning_rate=0.0001, cov_ema_decay=0.95, damping=0.001, layer_collection=layer_collection, placement_strategy="round_robin", cov_devices=device_list, inv_devices=device_list, trans_devices=device_list, momentum=0.9) with tf.device(device): train_op = optimizer.minimize(loss, global_step=g_step) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): global_step_, loss_, accuracy_, _ = sess.run( [g_step, loss, accuracy, train_op]) if global_step_ % _REPORT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) return accuracy_ def minimize_loss_single_machine_manual(loss, accuracy, layer_collection, device=None, session_config=None): """Minimize loss with K-FAC on a single machine(Illustrative purpose only). This function does inverse and covariance computation manually for illustrative pupose. Check `minimize_loss_single_machine` for automatic inverse and covariance op placement and execution. A single Session is responsible for running all of K-FAC's ops. The covariance and inverse update ops are placed on `device`. 
All model variables are on CPU. Args: loss: 0-D Tensor. Loss to be minimized. accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. device: string or None. The covariance and inverse update ops are run on this device. If empty or None, the default device will be used. (Default: None) session_config: None or tf.ConfigProto. Configuration for tf.Session(). Returns: final value for 'accuracy'. """ device_list = [] if not device else [device] # Train with K-FAC. g_step = tf.train.get_or_create_global_step() optimizer = kfac.KfacOptimizer( learning_rate=0.0001, cov_ema_decay=0.95, damping=0.001, layer_collection=layer_collection, placement_strategy="round_robin", cov_devices=device_list, inv_devices=device_list, trans_devices=device_list, momentum=0.9) (cov_update_thunks, inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() def make_update_op(update_thunks): update_ops = [thunk() for thunk in update_thunks] return tf.group(*update_ops) cov_update_op = make_update_op(cov_update_thunks) with tf.control_dependencies([cov_update_op]): inverse_op = tf.cond( tf.equal(tf.mod(g_step, _INVERT_EVERY), 0), lambda: make_update_op(inv_update_thunks), tf.no_op) with tf.control_dependencies([inverse_op]): with tf.device(device): train_op = optimizer.minimize(loss, global_step=g_step) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): global_step_, loss_, accuracy_, _ = sess.run( [g_step, loss, accuracy, train_op]) if global_step_ % _REPORT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) return accuracy_ def _is_gradient_task(task_id, num_tasks): """Returns True if this task should update the weights.""" if num_tasks < 3: return True return 0 <= task_id < 0.6 * num_tasks def _is_cov_update_task(task_id, num_tasks): """Returns True if this task should update K-FAC's covariance matrices.""" if num_tasks < 3: return False return 0.6 * num_tasks <= task_id < num_tasks - 1 def _is_inv_update_task(task_id, num_tasks): """Returns True if this task should update K-FAC's preconditioner.""" if num_tasks < 3: return False return task_id == num_tasks - 1 def _num_gradient_tasks(num_tasks): """Number of tasks that will update weights.""" if num_tasks < 3: return num_tasks return int(np.ceil(0.6 * num_tasks)) def _make_distributed_train_op( task_id, num_worker_tasks, num_ps_tasks, layer_collection ): """Creates optimizer and distributed training op. Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes the train op. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. If 0, parameter servers are not used. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. Returns: sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC optimizer. optimizer: Instance of `KfacOptimizer`. global_step: `tensor`, Global step. 
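# --- Illustrative sketch (not part of the original example) -----------------
# Worked example of the 60%/20%/20% task split implemented by the helpers
# above. With 10 worker tasks: tasks 0-5 compute gradients, tasks 6-8 update
# covariance statistics, and task 9 computes inverses; `_num_gradient_tasks`
# then equals 6, which is the value passed to `replicas_to_aggregate` when the
# SyncReplicasOptimizer is built. With fewer than 3 tasks, every task computes
# gradients and no task is dedicated to covariance or inverse updates.
def _example_print_task_roles(num_tasks=10):
  """Hedged sketch that prints which role each task id would be assigned."""
  for task_id in range(num_tasks):
    if _is_gradient_task(task_id, num_tasks):
      role = "gradient"
    elif _is_cov_update_task(task_id, num_tasks):
      role = "covariance update"
    elif _is_inv_update_task(task_id, num_tasks):
      role = "inverse update"
    else:
      role = "unassigned"
    print("task %d -> %s" % (task_id, role))
# -----------------------------------------------------------------------------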
""" tf.logging.info("Task id : %d", task_id) with tf.device(tf.train.replica_device_setter(num_ps_tasks)): global_step = tf.train.get_or_create_global_step() optimizer = kfac.KfacOptimizer( learning_rate=0.0001, cov_ema_decay=0.95, damping=0.001, layer_collection=layer_collection, momentum=0.9) sync_optimizer = tf.train.SyncReplicasOptimizer( opt=optimizer, replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks), total_num_replicas=num_worker_tasks) return sync_optimizer, optimizer, global_step def distributed_grads_only_and_ops_chief_worker( task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, loss, accuracy, layer_collection, invert_every=10): """Minimize loss with a synchronous implementation of K-FAC. All workers perform gradient computation. Chief worker applies gradient after averaging the gradients obtained from all the workers. All workers block execution until the update is applied. Chief worker runs covariance and inverse update ops. Covariance and inverse matrices are placed on parameter servers in a round robin manner. For further details on synchronous distributed optimization check `tf.train.SyncReplicasOptimizer`. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. If 0, parameter servers are not used. master: string. IP and port of TensorFlow runtime process. Set to empty string to run locally. checkpoint_dir: string or None. Path to store checkpoints under. loss: 0-D Tensor. Loss to be minimized. accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to run with each step. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. invert_every: `int`, Number of steps between update the inverse. Returns: final value for 'accuracy'. Raises: ValueError: if task_id >= num_worker_tasks. """ sync_optimizer, optimizer, global_step = _make_distributed_train_op( task_id, num_worker_tasks, num_ps_tasks, layer_collection) (cov_update_thunks, inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() tf.logging.info("Starting training.") hooks = [sync_optimizer.make_session_run_hook(is_chief)] def make_update_op(update_thunks): update_ops = [thunk() for thunk in update_thunks] return tf.group(*update_ops) if is_chief: cov_update_op = make_update_op(cov_update_thunks) with tf.control_dependencies([cov_update_op]): inverse_op = tf.cond( tf.equal(tf.mod(global_step, invert_every), 0), lambda: make_update_op(inv_update_thunks), tf.no_op) with tf.control_dependencies([inverse_op]): train_op = sync_optimizer.minimize(loss, global_step=global_step) else: train_op = sync_optimizer.minimize(loss, global_step=global_step) # Without setting allow_soft_placement=True there will be problems when # the optimizer tries to place certain ops like "mod" on the GPU (which isn't # supported). 
config = tf.ConfigProto(allow_soft_placement=True) with tf.train.MonitoredTrainingSession( master=master, is_chief=is_chief, checkpoint_dir=checkpoint_dir, hooks=hooks, stop_grace_period_secs=0, config=config) as sess: while not sess.should_stop(): global_step_, loss_, accuracy_, _ = sess.run( [global_step, loss, accuracy, train_op]) tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) return accuracy_ def distributed_grads_and_ops_dedicated_workers( task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, loss, accuracy, layer_collection): """Minimize loss with a synchronous implementation of K-FAC. Different workers are responsible for different parts of K-FAC's Ops. The first 60% of tasks compute gradients; the next 20% accumulate covariance statistics; the last 20% invert the matrices used to precondition gradients. The chief worker computes and applies the update. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. If 0, parameter servers are not used. master: string. IP and port of TensorFlow runtime process. Set to empty string to run locally. checkpoint_dir: string or None. Path to store checkpoints under. loss: 0-D Tensor. Loss to be minimized. accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to run with each step. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. Returns: final value for 'accuracy'. Raises: ValueError: if task_id >= num_worker_tasks. """ sync_optimizer, optimizer, global_step = _make_distributed_train_op( task_id, num_worker_tasks, num_ps_tasks, layer_collection) _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars() train_op = sync_optimizer.minimize(loss, global_step=global_step) inv_update_queue = kfac.op_queue.OpQueue(inv_update_ops) tf.logging.info("Starting training.") is_chief = (task_id == 0) hooks = [sync_optimizer.make_session_run_hook(is_chief)] # Without setting allow_soft_placement=True there will be problems when # the optimizer tries to place certain ops like "mod" on the GPU (which isn't # supported). config = tf.ConfigProto(allow_soft_placement=True) with tf.train.MonitoredTrainingSession( master=master, is_chief=is_chief, checkpoint_dir=checkpoint_dir, hooks=hooks, stop_grace_period_secs=0, config=config) as sess: while not sess.should_stop(): # Choose which op this task is responsible for running. if _is_gradient_task(task_id, num_worker_tasks): learning_op = train_op elif _is_cov_update_task(task_id, num_worker_tasks): learning_op = cov_update_op elif _is_inv_update_task(task_id, num_worker_tasks): learning_op = inv_update_queue.next_op(sess) else: raise ValueError("Which op should task %d do?" % task_id) global_step_, loss_, accuracy_, _ = sess.run( [global_step, loss, accuracy, learning_op]) tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) return accuracy_ def train_mnist_single_machine(num_epochs, use_fake_data=False, device=None, manual_op_exec=False): """Train a ConvNet on MNIST. Args: num_epochs: int. Number of passes to make over the training set. use_fake_data: bool. If True, generate a synthetic dataset. device: string or None. The covariance and inverse update ops are run on this device. 
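# --- Illustrative sketch (not part of the original example) -----------------
# The same "covariance -> (periodic) inverse -> parameter update" dependency
# chain appears in `minimize_loss_single_machine_manual` and in
# `distributed_grads_only_and_ops_chief_worker` above. This hedged helper just
# isolates that pattern; `opt` is assumed to be a `kfac.KfacOptimizer` whose
# layer collection has already been registered.
def _example_chain_kfac_updates(opt, loss, global_step, invert_every=10):
  """Runs cov updates every step, inverses every `invert_every` steps, and
  only then applies the K-FAC parameter update."""
  cov_thunks, inv_thunks = opt.make_vars_and_create_op_thunks()

  def make_update_op(thunks):
    return tf.group(*[thunk() for thunk in thunks])

  cov_update_op = make_update_op(cov_thunks)
  with tf.control_dependencies([cov_update_op]):
    inverse_op = tf.cond(
        tf.equal(tf.mod(global_step, invert_every), 0),
        lambda: make_update_op(inv_thunks),
        tf.no_op)
  with tf.control_dependencies([inverse_op]):
    return opt.minimize(loss, global_step=global_step)
# -----------------------------------------------------------------------------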
If empty or None, the default device will be used. (Default: None) manual_op_exec: bool, If `True` then `minimize_loss_single_machine_manual` is called for training which handles inverse and covariance computation. This is shown only for illustrative purpose. Otherwise `minimize_loss_single_machine` is called which relies on `PeriodicInvCovUpdateOpt` for op placement and execution. Returns: accuracy of model on the final minibatch of training data. """ # Load a dataset. tf.logging.info("Loading MNIST into memory.") (examples, labels) = mnist.load_mnist_as_iterator(num_epochs, 128, use_fake_data=use_fake_data, flatten_images=False) # Build a ConvNet. layer_collection = kfac.LayerCollection() loss, accuracy = build_model( examples, labels, num_labels=10, layer_collection=layer_collection, register_layers_manually=_USE_MANUAL_REG) if not _USE_MANUAL_REG: layer_collection.auto_register_layers() # Without setting allow_soft_placement=True there will be problems when # the optimizer tries to place certain ops like "mod" on the GPU (which isn't # supported). config = tf.ConfigProto(allow_soft_placement=True) # Fit model. if manual_op_exec: return minimize_loss_single_machine_manual( loss, accuracy, layer_collection, device=device, session_config=config) else: return minimize_loss_single_machine( loss, accuracy, layer_collection, device=device, session_config=config) def train_mnist_multitower(num_epochs, num_towers, devices, use_fake_data=False, session_config=None): """Train a ConvNet on MNIST. Training data is split equally among the towers. Each tower computes loss on its own batch of data and the loss is aggregated on the CPU. The model variables are placed on first tower. The covariance and inverse update ops and variables are placed on specified devices in a round robin manner. Args: num_epochs: int. Number of passes to make over the training set. num_towers: int. Number of towers. devices: list of strings. List of devices to place the towers. use_fake_data: bool. If True, generate a synthetic dataset. session_config: None or tf.ConfigProto. Configuration for tf.Session(). Returns: accuracy of model on the final minibatch of training data. """ num_towers = 1 if not devices else len(devices) # Load a dataset. tf.logging.info("Loading MNIST into memory.") tower_batch_size = 128 batch_size = tower_batch_size * num_towers tf.logging.info( ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d " "tower batch size.") % (batch_size, num_towers, tower_batch_size)) (examples, labels) = mnist.load_mnist_as_iterator(num_epochs, batch_size, use_fake_data=use_fake_data, flatten_images=False) # Split minibatch across towers. examples = tf.split(examples, num_towers) labels = tf.split(labels, num_towers) # Build an MLP. Each tower's layers will be added to the LayerCollection. layer_collection = kfac.LayerCollection() tower_results = [] for tower_id in range(num_towers): with tf.device(devices[tower_id]): with tf.name_scope("tower%d" % tower_id): with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): tf.logging.info("Building tower %d." % tower_id) tower_results.append( build_model( examples[tower_id], labels[tower_id], 10, layer_collection, register_layers_manually=_USE_MANUAL_REG)) losses, accuracies = zip(*tower_results) # When using multiple towers we only want to perform automatic # registation once, after the final tower is made if not _USE_MANUAL_REG: layer_collection.auto_register_layers() # Average across towers. 
loss = tf.reduce_mean(losses) accuracy = tf.reduce_mean(accuracies) # Fit model. g_step = tf.train.get_or_create_global_step() optimizer = kfac.PeriodicInvCovUpdateKfacOpt( invert_every=_INVERT_EVERY, cov_update_every=_COV_UPDATE_EVERY, learning_rate=0.0001, cov_ema_decay=0.95, damping=0.001, layer_collection=layer_collection, placement_strategy="round_robin", cov_devices=devices, inv_devices=devices, trans_devices=devices, momentum=0.9) with tf.device(devices[0]): train_op = optimizer.minimize(loss, global_step=g_step) # Without setting allow_soft_placement=True there will be problems when # the optimizer tries to place certain ops like "mod" on the GPU (which isn't # supported). if not session_config: session_config = tf.ConfigProto(allow_soft_placement=True) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): global_step_, loss_, accuracy_, _ = sess.run( [g_step, loss, accuracy, train_op]) if global_step_ % _REPORT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) def train_mnist_distributed_sync_replicas(task_id, is_chief, num_worker_tasks, num_ps_tasks, master, num_epochs, op_strategy, use_fake_data=False): """Train a ConvNet on MNIST using Sync replicas optimizer. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. master: string. IP and port of TensorFlow runtime process. num_epochs: int. Number of passes to make over the training set. op_strategy: `string`, Strategy to run the covariance and inverse ops. If op_strategy == `chief_worker` then covariance and inverse update ops are run on chief worker otherwise they are run on dedicated workers. use_fake_data: bool. If True, generate a synthetic dataset. Returns: accuracy of model on the final minibatch of training data. Raises: ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"]. """ # Load a dataset. tf.logging.info("Loading MNIST into memory.") (examples, labels) = mnist.load_mnist_as_iterator(num_epochs, 128, use_fake_data=use_fake_data, flatten_images=False) # Build a ConvNet. layer_collection = kfac.LayerCollection() with tf.device(tf.train.replica_device_setter(num_ps_tasks)): loss, accuracy = build_model( examples, labels, num_labels=10, layer_collection=layer_collection, register_layers_manually=_USE_MANUAL_REG) if not _USE_MANUAL_REG: layer_collection.auto_register_layers() # Fit model. checkpoint_dir = None if op_strategy == "chief_worker": return distributed_grads_only_and_ops_chief_worker( task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, loss, accuracy, layer_collection) elif op_strategy == "dedicated_workers": return distributed_grads_and_ops_dedicated_workers( task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, loss, accuracy, layer_collection) else: raise ValueError("Only supported op strategies are : {}, {}".format( "chief_worker", "dedicated_workers")) def train_mnist_estimator(num_epochs, use_fake_data=False): """Train a ConvNet on MNIST using tf.estimator. Args: num_epochs: int. Number of passes to make over the training set. use_fake_data: bool. If True, generate a synthetic dataset. Returns: accuracy of model on the final minibatch of training data. """ # Load a dataset. 
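# --- Illustrative sketch (not part of the original example) -----------------
# Hedged examples of how the two training entry points defined above might be
# invoked. The device strings and cluster parameters below are placeholders;
# launching the tf.train.Server / cluster processes for the distributed
# variant is outside the scope of this sketch.
def _example_run_multitower():
  # Two towers, one per GPU; the per-tower losses are averaged on the CPU.
  return train_mnist_multitower(
      num_epochs=1, num_towers=2, devices=["/gpu:0", "/gpu:1"],
      use_fake_data=True)

def _example_run_sync_replicas(task_id, master):
  # Each worker process would call this with its own task_id and master
  # address; under the "chief_worker" strategy the chief (task 0) also runs
  # the covariance and inverse update ops.
  return train_mnist_distributed_sync_replicas(
      task_id=task_id, is_chief=(task_id == 0), num_worker_tasks=4,
      num_ps_tasks=1, master=master, num_epochs=1,
      op_strategy="chief_worker", use_fake_data=True)
# -----------------------------------------------------------------------------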
def input_fn(): tf.logging.info("Loading MNIST into memory.") return mnist.load_mnist_as_iterator(num_epochs=num_epochs, batch_size=64, flatten_images=False, use_fake_data=use_fake_data) def model_fn(features, labels, mode, params): """Model function for MLP trained with K-FAC. Args: features: Tensor of shape [batch_size, input_size]. Input features. labels: Tensor of shape [batch_size]. Target labels for training. mode: tf.estimator.ModeKey. Must be TRAIN. params: ignored. Returns: EstimatorSpec for training. Raises: ValueError: If 'mode' is anything other than TRAIN. """ del params if mode != tf.estimator.ModeKeys.TRAIN: raise ValueError("Only training is supported with this API.") # Build a ConvNet. layer_collection = kfac.LayerCollection() loss, accuracy = build_model( features, labels, num_labels=10, layer_collection=layer_collection, register_layers_manually=_USE_MANUAL_REG) if not _USE_MANUAL_REG: layer_collection.auto_register_layers() # Train with K-FAC. global_step = tf.train.get_or_create_global_step() optimizer = kfac.KfacOptimizer( learning_rate=tf.train.exponential_decay( 0.00002, global_step, 10000, 0.5, staircase=True), cov_ema_decay=0.95, damping=0.001, layer_collection=layer_collection, momentum=0.9) (cov_update_thunks, inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() def make_update_op(update_thunks): update_ops = [thunk() for thunk in update_thunks] return tf.group(*update_ops) def make_batch_executed_op(update_thunks, batch_size=1): return tf.group(*kfac.utils.batch_execute( global_step, update_thunks, batch_size=batch_size)) # Run cov_update_op every step. Run 1 inv_update_ops per step. cov_update_op = make_update_op(cov_update_thunks) with tf.control_dependencies([cov_update_op]): # But make sure to execute all the inverse ops on the first step inverse_op = tf.cond(tf.equal(global_step, 0), lambda: make_update_op(inv_update_thunks), lambda: make_batch_executed_op(inv_update_thunks)) with tf.control_dependencies([inverse_op]): train_op = optimizer.minimize(loss, global_step=global_step) # Print metrics every 5 sec. hooks = [ tf.train.LoggingTensorHook( { "loss": loss, "accuracy": accuracy }, every_n_secs=5), ] return tf.estimator.EstimatorSpec( mode=mode, loss=loss, train_op=train_op, training_hooks=hooks) run_config = tf.estimator.RunConfig( model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100) # Train until input_fn() is empty with Estimator. This is a prerequisite for # TPU compatibility. estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config) estimator.train(input_fn=input_fn) ================================================ FILE: kfac/examples/keras/KFAC_vs_Adam_Experiment.md ================================================ # KFAC vs Adam Experiment ## Set Up We compare KFAC and Adam on a RESNET-20 on the CIFAR10 dataset. We split CIFAR10 into a training (40k), validation (10k), and test (10k) sets. We ran a random hyperparameter search where the best hyperparameters were chosen by the run that reaches 89% validation accuracy first in terms of number of steps. We decay both learning rate and damping/epsilon exponentially. The final learning rate is fixed at 1e-4, final damping (KFAC) at 1e-6, and final epsilon (Adam) at 1e-8. Below are the ranges of the tuned hyperparamters. 
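For readers unfamiliar with log-uniform sampling, the sketch below shows one common way such a candidate value can be drawn; the helper name is ours and is not part of the experiment code, and the range shown is just the initial learning rate range from the table that follows.

```python
import numpy as np

def sample_log_uniform(low, high, rng=np.random):
    """Sample a value log-uniformly from [low, high]."""
    return 10 ** rng.uniform(np.log10(low), np.log10(high))

# e.g. a candidate initial learning rate for KFAC drawn from [1e-2, 10.0]
init_learning_rate = sample_log_uniform(1e-2, 10.0)
```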
The random search samples all the hyperparameters from a log uniform scale: | Hyperparameter | Min | Max | |---------------------------|------|-------| | Init Learning Rate | 1e-2 | 10.0 | | Init Damping (KFAC) | 1e-2 | 100.0 | | Init Epsilon (Adam) | 1e-4 | 1.0 | | 1 - Learning Rate Decay | 1e-4 | 0.1 | | 1 - Damping/Epsilon Decay | 1e-4 | 0.1 | | 1 - Momentum | 1e-2 | 0.3 | The initial tuning run was with seed 20190524 with the GPU training script on an NVIDIA Tesla P100. Then, after choosing the best hyperparameters, we ran those hyperparameters with the following 10 random seeds: 351515, 382980, 934126, 891369, 64379, 402680, 672242, 421590, 498163, 448799. # Results The chosen hyperparameters were the following (to 6 decimal places): | Hyperparameter | KFAC | Adam | |---------------------------|----------|----------| | Init Learning Rate | 0.227214 | 2.242663 | | Init Damping (KFAC) | 0.288721 | | | Init Epsilon (Adam) | | 0.183230 | | 1 - Learning Rate Decay | 0.001090 | 0.000610 | | 1 - Damping/Epsilon Decay | 0.000287 | 0.000213 | | 1 - Momentum | 0.018580 | 0.029656 | ## Training Curves Below are the loss and accuracy training curves with the training and test sets. The line represents the mean of the 10 seed runs and the coloured region represents the bootstrapped standard deviation. KFAC reaches 89% validation accuracy at step 4640 and Adam at step 6560 (measurements were taken every 40 steps). ![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_v_adam_loss_curve.png?raw=true) ![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_v_adam_accuracy_curve.png?raw=true) Among the other runs, KFAC decreases training loss quicker than Adam early in training, then show similar performance later in training. ## Hyperparameter Analysis We offer some analysis of the learning rate and damping for KFAC to aid in choosing appropriate values for these hyperparameters. Plots with the rest of the hyperparameters for both KFAC and Adam are in the plots folder. ![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_lr_v_damping.png?raw=true) In general, a higher learning rate requires a higher damping. A large learning rate with low damping leads to divergence, whereas a low learning rate with high damping leads to SGD-like behaviour, which is suboptimal. The plot above shows little correlation due to the decay schedules playing a large role, which is shown below: ![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_damping_v_dampingdecay.png?raw=true) A fast damping decay allows for faster training, but can easily lead to divergence. The best runs are often close to diverging. ![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_lr_v_lrdecay.png?raw=true) As expected, a high learning rate with a low decay can lead to divergence. ![](https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/plots/kfac_lrdecay_v_dampingdecay.png?raw=true) Just like with the learning rate and damping, the learning rate decay should be proportional the damping decay to prevent divergence while training quickly. ================================================ FILE: kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_DDaAex5Q7u-" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." 
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "cellView": "both", "colab": {}, "colab_type": "code", "id": "W1dWWdNHQ9L0" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_C170SDp6jBt" }, "source": [ "# KFAC vs Adam on CIFAR10 on a GPU\n", "\n", "This notebook contains the code used to run the experiment comparing KFAC and Adam on CIFAR 10 with a Resnet-20. This was run on a NVIDIA Tesla P100 for the experiment. It can be run on a public GPU colab instance.\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tensorflow/kfac/blob/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10.ipynb)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "rw0qz2RWkLeJ" }, "outputs": [], "source": [ "!pip install kfac" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "LfGyhnaOsgYu" }, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow_datasets as tfds\n", "import math\n", "import kfac" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "DYWIY0C380ye" }, "outputs": [], "source": [ "TRAINING_SIZE = 40000\n", "VALIDATION_SIZE = 10000\n", "TEST_SIZE = 10000\n", "SEED = 20190524\n", "\n", "num_training_steps = 7500\n", "batch_size = 1000\n", "layers = tf.keras.layers\n", "\n", "# We take the ceiling because we do not drop the remainder of the batch\n", "compute_steps_per_epoch = lambda x: int(math.ceil(1. 
* x / batch_size))\n", "steps_per_epoch = compute_steps_per_epoch(TRAINING_SIZE)\n", "val_steps = compute_steps_per_epoch(VALIDATION_SIZE)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "GfeTgsbh5G4g" }, "outputs": [], "source": [ "optimizer_name = 'kfac' # 'kfac' or 'adam'\n", "\n", "# Best Hyperparameters from the Random Search\n", "if optimizer_name == 'kfac':\n", " init_learning_rate = 0.22721400059936694\n", " final_learning_rate = 1e-04\n", " init_damping = 0.28872127217018184\n", " final_damping = 1e-6\n", " momentum = 1 - 0.018580394981260295\n", " lr_decay_rate = 1 - 0.001090107322908028\n", " damping_decay_rate = 1 - 0.0002870880729016523\n", "elif optimizer_name == 'adam':\n", " init_learning_rate = 2.24266320779\n", " final_learning_rate = 1e-4\n", " init_epsilon = 0.183230038808\n", " final_epsilon = 1e-8\n", " momentum = 1 - 0.0296561513388\n", " lr_decay_rate = 1 - 0.000610416031571\n", " epsilon_decay_rate = 1 - 0.000212682338199\n", "else:\n", " raise ValueError('Ensure optimizer_name is kfac or adam')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "v3vSki-usp9k" }, "source": [ "## Input Pipeline" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "D2U3i5kgssy_" }, "outputs": [], "source": [ "def _parse_fn(x):\n", " image, label = x['image'], x['label']\n", " image = tf.cast(image, tf.float32)\n", " label = tf.cast(label, tf.int32)\n", " image = image / 127.5 - 1\n", " return image, label\n", "\n", "\n", "def _augment_image(image, crop_amount, seed=None):\n", " # Random Brightness, Contrast, Jpeg Quality, Hue, and Saturation did not\n", " # seem to work well as augmentations for our training specifications\n", " input_shape = image.shape.as_list()\n", " cropped_size = [input_shape[0] - crop_amount,\n", " input_shape[1] - crop_amount,\n", " input_shape[2]]\n", " flipped = tf.image.random_flip_left_right(image, seed)\n", " cropped = tf.image.random_crop(flipped, cropped_size, seed)\n", " return tf.image.pad_to_bounding_box(image=cropped,\n", " offset_height=crop_amount // 2,\n", " offset_width=crop_amount // 2,\n", " target_height=input_shape[0],\n", " target_width=input_shape[1])\n", "\n", "\n", "def _get_raw_data():\n", " # We split the training data into training and validation ourselves for\n", " # hyperparameter tuning.\n", " training_pct = int(100.0 * TRAINING_SIZE / (TRAINING_SIZE + VALIDATION_SIZE))\n", " train_split = tfds.Split.TRAIN.subsplit(tfds.percent[:training_pct])\n", " validation_split = tfds.Split.TRAIN.subsplit(tfds.percent[training_pct:])\n", "\n", " train_data, info = tfds.load('cifar10:3.*.*', with_info=True, split=train_split)\n", " val_data = tfds.load('cifar10:3.*.*', split=validation_split)\n", " test_data = tfds.load('cifar10:3.*.*', split='test')\n", "\n", " input_shape = info.features['image'].shape\n", " num_classes = info.features['label'].num_classes\n", " info = {'input_shape': input_shape, 'num_classes': num_classes}\n", " return info, train_data, val_data, test_data\n", "\n", "\n", "def get_input_pipeline(batch_size=None,\n", " use_augmentation=True,\n", " seed=None,\n", " crop_amount=6,\n", " drop_remainder=False,\n", " repeat_validation=True):\n", " \"\"\"Creates CIFAR10 Data Pipeline.\n", "\n", " Args:\n", " batch_size (int): Batch size used for training.\n", " use_augmentation (bool): If true, applies random horizontal flips and crops\n", " then pads to images.\n", " seed (int): Random seed used 
for augmentation operations.\n", " crop_amount (int): Number of pixels to crop from the height and width of the\n", " image. So, the cropped image will be [height - crop_amount, width -\n", " crop_amount, channels] before it is padded to restore its original size.\n", " drop_remainder (bool): Whether to drop the remainder of the batch. Needs to\n", " be true to work on TPUs.\n", " repeat_validation (bool): Whether to repeat the validation set. Test set is\n", " never repeated.\n", "\n", " Returns:\n", " A tuple with an info dict (with input_shape (tuple) and number of classes\n", " (int)) and data dict (train_data (tf.DatasetAdapter), validation_data,\n", " (tf.DatasetAdapter) and test_data (tf.DatasetAdapter))\n", " \"\"\"\n", " info, train_data, val_data, test_data = _get_raw_data()\n", "\n", " if not batch_size:\n", " batch_size = max(TRAINING_SIZE, VALIDATION_SIZE, TEST_SIZE)\n", "\n", " train_data = train_data.map(_parse_fn).shuffle(8192, seed=seed).repeat()\n", " if use_augmentation:\n", " train_data = train_data.map(\n", " lambda x, y: (_augment_image(x, crop_amount, seed), y))\n", " train_data = train_data.batch(\n", " min(batch_size, TRAINING_SIZE), drop_remainder=drop_remainder)\n", " train_data = train_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n", "\n", " val_data = val_data.map(_parse_fn)\n", " if repeat_validation:\n", " val_data = val_data.repeat()\n", " val_data = val_data.batch(\n", " min(batch_size, VALIDATION_SIZE), drop_remainder=drop_remainder)\n", " val_data = val_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n", "\n", " # Don't repeat test data because it is only used once to evaluate at the end.\n", " test_data = test_data.map(_parse_fn)\n", " if repeat_validation:\n", " test_data = test_data.repeat()\n", " test_data = test_data.batch(\n", " min(batch_size, TEST_SIZE), drop_remainder=drop_remainder)\n", " test_data = test_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n", "\n", " data = {'train': train_data, 'validation': val_data, 'test': test_data}\n", " return data, info" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "SLvlpsups2aR" }, "source": [ "## Model - Resnet V2\n", "\n", "Based on https://keras.io/examples/cifar10_resnet/. The only difference is that tf.keras layer implementations are used." 
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "Cch3Ld5Ds4i2" }, "outputs": [], "source": [ "def resnet_layer(inputs,\n", " num_filters=16,\n", " kernel_size=3,\n", " strides=1,\n", " activation='relu',\n", " batch_normalization=True,\n", " conv_first=True):\n", " \"\"\"2D Convolution-Batch Normalization-Activation stack builder.\n", "\n", " Based on https://keras.io/examples/cifar10_resnet/.\n", "\n", " Args:\n", " inputs (tensor): input tensor from input image or previous layer\n", " num_filters (int): Conv2D number of filters\n", " kernel_size (int): Conv2D square kernel dimensions\n", " strides (int): Conv2D square stride dimensions\n", " activation (string): activation name\n", " batch_normalization (bool): whether to include batch normalization\n", " conv_first (bool): conv-bn-activation (True) or bn-activation-conv (False)\n", "\n", " Returns:\n", " x (tensor): tensor as input to the next layer\n", " \"\"\"\n", " conv = layers.Conv2D(num_filters,\n", " kernel_size=kernel_size,\n", " strides=strides,\n", " padding='same',\n", " kernel_initializer='he_normal',\n", " kernel_regularizer=tf.keras.regularizers.l2(1e-4))\n", "\n", " x = inputs\n", " if conv_first:\n", " x = conv(x)\n", " if batch_normalization:\n", " x = layers.BatchNormalization()(x)\n", " if activation is not None:\n", " x = layers.Activation(activation)(x)\n", " else:\n", " if batch_normalization:\n", " x = layers.BatchNormalization()(x)\n", " if activation is not None:\n", " x = layers.Activation(activation)(x)\n", " x = conv(x)\n", " return x\n", "\n", "\n", "def resnet_v2(input_shape, depth, num_classes=10):\n", " \"\"\"ResNet Version 2 Model builder [b].\n", "\n", " Based on https://keras.io/examples/cifar10_resnet/.\n", "\n", " Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D or also known as\n", " bottleneck layer\n", " First shortcut connection per layer is 1 x 1 Conv2D.\n", " Second and onwards shortcut connection is identity.\n", " At the beginning of each stage, the feature map size is halved (downsampled)\n", " by a convolutional layer with strides=2, while the number of filter maps is\n", " doubled. 
Within each stage, the layers have the same number filters and the\n", " same filter map sizes.\n", " Features maps sizes:\n", " conv1 : 32x32, 16\n", " stage 0: 32x32, 64\n", " stage 1: 16x16, 128\n", " stage 2: 8x8, 256\n", "\n", " Args:\n", " input_shape (tuple/list): shape of input image tensor\n", " depth (int): number of core convolutional layers\n", " num_classes (int): number of classes (CIFAR10 has 10)\n", "\n", " Returns:\n", " model (Model): Keras model instance\n", " \"\"\"\n", " if (depth - 2) % 9 != 0:\n", " raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])')\n", " # Start model definition.\n", " num_filters_in = 16\n", " num_res_blocks = int((depth - 2) / 9)\n", "\n", " inputs = tf.keras.Input(shape=input_shape)\n", " # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths\n", " x = resnet_layer(inputs=inputs, num_filters=num_filters_in, conv_first=True)\n", "\n", " # Instantiate the stack of residual units\n", " for stage in range(3):\n", " for res_block in range(num_res_blocks):\n", " activation = 'relu'\n", " batch_normalization = True\n", " strides = 1\n", " if stage == 0:\n", " num_filters_out = num_filters_in * 4\n", " if res_block == 0: # first layer and first stage\n", " activation = None\n", " batch_normalization = False\n", " else:\n", " num_filters_out = num_filters_in * 2\n", " if res_block == 0: # first layer but not first stage\n", " strides = 2 # downsample\n", "\n", " # bottleneck residual unit\n", " y = resnet_layer(inputs=x,\n", " num_filters=num_filters_in,\n", " kernel_size=1,\n", " strides=strides,\n", " activation=activation,\n", " batch_normalization=batch_normalization,\n", " conv_first=False)\n", " y = resnet_layer(inputs=y, num_filters=num_filters_in, conv_first=False)\n", " y = resnet_layer(inputs=y,\n", " num_filters=num_filters_out,\n", " kernel_size=1,\n", " conv_first=False)\n", " if res_block == 0:\n", " # linear projection residual shortcut connection to match\n", " # changed dims\n", " x = resnet_layer(inputs=x,\n", " num_filters=num_filters_out,\n", " kernel_size=1,\n", " strides=strides,\n", " activation=None,\n", " batch_normalization=False)\n", " x = layers.Add()([x, y])\n", "\n", " num_filters_in = num_filters_out\n", "\n", " # Add classifier on top.\n", " # v2 has BN-ReLU before Pooling\n", " x = layers.BatchNormalization()(x)\n", " x = layers.Activation('relu')(x)\n", " x = layers.AveragePooling2D(pool_size=8)(x)\n", " y = layers.Flatten()(x)\n", " outputs = layers.Dense(num_classes,\n", " activation='softmax',\n", " kernel_initializer='he_normal')(y)\n", "\n", " # Instantiate model.\n", " model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", " return model" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "dAUaN-i9tHMY" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "Hf5WFHYP8tT9" }, "outputs": [], "source": [ "tf.reset_default_graph()\n", "tf.set_random_seed(SEED)\n", "\n", "data, info = get_input_pipeline(batch_size=batch_size,\n", " seed=SEED,\n", " repeat_validation=True,\n", " use_augmentation=True)\n", "\n", "model = resnet_v2(input_shape=info['input_shape'],\n", " depth=20,\n", " num_classes=info['num_classes'])\n", "\n", "loss = 'sparse_categorical_crossentropy'\n", "\n", "training_callbacks = [\n", " kfac.keras.callbacks.ExponentialDecay(hyperparameter='learning_rate',\n", " init_value=init_learning_rate,\n", " final_value=final_learning_rate,\n", " decay_rate=lr_decay_rate)\n", 
"]\n", "\n", "if optimizer_name == 'kfac':\n", " opt = kfac.keras.optimizers.Kfac(learning_rate=init_learning_rate,\n", " damping=init_damping,\n", " model=model,\n", " loss=loss,\n", " momentum=momentum,\n", " seed=SEED)\n", " training_callbacks.append(kfac.keras.callbacks.ExponentialDecay(\n", " hyperparameter='damping',\n", " init_value=init_damping,\n", " final_value=final_damping,\n", " decay_rate=damping_decay_rate))\n", "\n", "elif optimizer_name == 'adam':\n", " opt = tf.keras.optimizers.Adam(learning_rate=init_learning_rate,\n", " beta_1=momentum,\n", " epsilon=init_epsilon)\n", " training_callbacks.append(kfac.keras.callbacks.ExponentialDecay(\n", " hyperparameter='epsilon',\n", " init_value=init_epsilon,\n", " final_value=final_epsilon,\n", " decay_rate=epsilon_decay_rate))\n", "\n", "else:\n", " raise ValueError('optimizer_name must be \"adam\" or \"kfac\"')\n", "\n", "model.compile(loss=loss, optimizer=opt, metrics=['acc'])" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "dD8b27hLy6lO" }, "outputs": [], "source": [ "history = model.fit(x=data['train'],\n", " epochs=num_training_steps//steps_per_epoch,\n", " steps_per_epoch=steps_per_epoch,\n", " validation_data=data['validation'],\n", " validation_steps=val_steps,\n", " callbacks=training_callbacks)" ] } ], "metadata": { "colab": { "collapsed_sections": [ "_DDaAex5Q7u-" ], "last_runtime": { "build_target": "", "kind": "local" }, "name": "KFAC vs Adam on CIFAR10.ipynb", "provenance": [ { "file_id": "1pqtoYduODZyJKt4-kwVkt_KtNQCnaNDp", "timestamp": 1565229994386 } ], "version": "0.3.2" }, "kernelspec": { "display_name": "Python 2", "name": "python2" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_DDaAex5Q7u-" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "cellView": "both", "colab": {}, "colab_type": "code", "id": "W1dWWdNHQ9L0" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "KDGkOqGA54FB" }, "source": [ "# KFAC vs Adam on CIFAR10 on TPUs\n", "\n", "This notebook demonstrates how to write a custom training loop with TPU Strategy with KFAC. It can be run on a public TPU colab instance. The key differences between using KFAC with TPU Strategy and using TPU Strategy normally are that the model and optimizer must be created in your train step and KFAC does not work with model.fit (because of the first condition). 
We also use a batch_size of 1024 instead of 1000 to better utilize TPUs.\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tensorflow/kfac/blob/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "At1AvF75kmlr" }, "outputs": [], "source": [ "!pip install kfac" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "LfGyhnaOsgYu" }, "outputs": [], "source": [ "import tensorflow as tf\n", "import math\n", "import kfac\n", "import os" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "DYWIY0C380ye" }, "outputs": [], "source": [ "TRAINING_SIZE = 40000\n", "VALIDATION_SIZE = 10000\n", "TEST_SIZE = 10000\n", "SEED = 20190524\n", "\n", "num_training_steps = 7500\n", "# We use a batch size of 1024 instead 1000 because each TPU core should\n", "# (ideally) get a batch whose size is a multiple 128 (here we have 8 cores)\n", "batch_size = 1024\n", "layers = tf.keras.layers\n", "\n", "compute_steps_per_epoch = lambda x: int(math.floor(1. * x / batch_size))\n", "steps_per_epoch = compute_steps_per_epoch(TRAINING_SIZE)\n", "val_steps = compute_steps_per_epoch(VALIDATION_SIZE)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "GfeTgsbh5G4g" }, "outputs": [], "source": [ "optimizer_name = 'kfac' # 'kfac' or 'adam'\n", "\n", "# Best Hyperparameters from the Random Search\n", "if optimizer_name == 'kfac':\n", " init_learning_rate = 0.22721400059936694\n", " final_learning_rate = 1e-04\n", " init_damping = 0.28872127217018184\n", " final_damping = 1e-6\n", " momentum = 1 - 0.018580394981260295\n", " lr_decay_rate = 1 - 0.001090107322908028\n", " damping_decay_rate = 1 - 0.0002870880729016523\n", "elif optimizer_name == 'adam':\n", " init_learning_rate = 2.24266320779\n", " final_learning_rate = 1e-4\n", " init_epsilon = 0.183230038808\n", " final_epsilon = 1e-8\n", " momentum = 1 - 0.0296561513388\n", " lr_decay_rate = 1 - 0.000610416031571\n", " epsilon_decay_rate = 1 - 0.000212682338199\n", "else:\n", " raise ValueError('Ensure optimizer_name is kfac or adam')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "v3vSki-usp9k" }, "source": [ "## Input Pipeline\n", "\n", "The tensorflow_datasets CIFAR10 dataset (used in the GPU notebook) does not work with the public TPUs, because of the way the tf.data.Dataset is downloaded. So, this pipeline uses the tf.keras version, which downloads numpy arrays that we turn into a tf.data.Dataset.\n", "\n", "If this pipeline does not work, try using the pipeline in the GPU notebook instead." 
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "D2U3i5kgssy_" }, "outputs": [], "source": [ "def _parse_fn(image, label):\n", " image = tf.cast(image, tf.float32)\n", " label = tf.cast(tf.squeeze(label), tf.int32)\n", " image = image / 127.5 - 1\n", " return image, label\n", "\n", "\n", "def _augment_image(image, crop_amount, seed=None):\n", " # Random Brightness, Contrast, Jpeg Quality, Hue, and Saturation did not\n", " # seem to work well as augmentations for our training specifications\n", " input_shape = image.shape.as_list()\n", " cropped_size = [input_shape[0] - crop_amount,\n", " input_shape[1] - crop_amount,\n", " input_shape[2]]\n", " flipped = tf.image.random_flip_left_right(image, seed)\n", " cropped = tf.image.random_crop(flipped, cropped_size, seed)\n", " return tf.image.pad_to_bounding_box(image=cropped,\n", " offset_height=crop_amount // 2,\n", " offset_width=crop_amount // 2,\n", " target_height=input_shape[0],\n", " target_width=input_shape[1])\n", "\n", "\n", "def _get_raw_data():\n", " # We split the training data into training and validation ourselves for\n", " # hyperparameter tuning.\n", " train_and_val, test = tf.keras.datasets.cifar10.load_data()\n", " train = (train_and_val[0][:TRAINING_SIZE], train_and_val[1][:TRAINING_SIZE])\n", " val = (train_and_val[0][TRAINING_SIZE:], train_and_val[1][TRAINING_SIZE:])\n", " info = {'input_shape':train_and_val[0].shape[1:], 'num_classes':10}\n", " return (info,\n", " tf.data.Dataset.from_tensor_slices(train),\n", " tf.data.Dataset.from_tensor_slices(val),\n", " tf.data.Dataset.from_tensor_slices(test))\n", "\n", "\n", "def get_input_pipeline(batch_size=None,\n", " use_augmentation=True,\n", " seed=None,\n", " crop_amount=6,\n", " drop_remainder=False,\n", " repeat_validation=True):\n", " \"\"\"Creates CIFAR10 Data Pipeline.\n", "\n", " Args:\n", " batch_size (int): Batch size used for training.\n", " use_augmentation (bool): If true, applies random horizontal flips and crops\n", " then pads to images.\n", " seed (int): Random seed used for augmentation operations.\n", " crop_amount (int): Number of pixels to crop from the height and width of the\n", " image. So, the cropped image will be [height - crop_amount, width -\n", " crop_amount, channels] before it is padded to restore its original size.\n", " drop_remainder (bool): Whether to drop the remainder of the batch. Needs to\n", " be true to work on TPUs.\n", " repeat_validation (bool): Whether to repeat the validation set. 
Test set is\n", " never repeated.\n", "\n", " Returns:\n", " A tuple with an info dict (with input_shape (tuple) and number of classes\n", " (int)) and data dict (train_data (tf.DatasetAdapter), validation_data,\n", " (tf.DatasetAdapter) and test_data (tf.DatasetAdapter))\n", " \"\"\"\n", " info, train_data, val_data, test_data = _get_raw_data()\n", "\n", " if not batch_size:\n", " batch_size = max(TRAINING_SIZE, VALIDATION_SIZE, TEST_SIZE)\n", "\n", " train_data = train_data.map(_parse_fn).shuffle(8192, seed=seed).repeat()\n", " if use_augmentation:\n", " train_data = train_data.map(\n", " lambda x, y: (_augment_image(x, crop_amount, seed), y))\n", " train_data = train_data.batch(\n", " min(batch_size, TRAINING_SIZE), drop_remainder=drop_remainder)\n", " train_data = train_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n", "\n", " val_data = val_data.map(_parse_fn)\n", " if repeat_validation:\n", " val_data = val_data.repeat()\n", " val_data = val_data.batch(\n", " min(batch_size, VALIDATION_SIZE), drop_remainder=drop_remainder)\n", " val_data = val_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n", "\n", " # Don't repeat test data because it is only used once to evaluate at the end.\n", " test_data = test_data.map(_parse_fn)\n", " if repeat_validation:\n", " test_data = test_data.repeat()\n", " test_data = test_data.batch(\n", " min(batch_size, TEST_SIZE), drop_remainder=drop_remainder)\n", " test_data = test_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n", "\n", " data = {'train': train_data, 'validation': val_data, 'test': test_data}\n", " return data, info" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "SLvlpsups2aR" }, "source": [ "## Model - Resnet V2\n", "\n", "Based on https://keras.io/examples/cifar10_resnet/. The only difference is that tf.keras layer implementations are used, the model outputs logits instead of a probability distribution, and an input tensor is used instead of the input shape (since TPUs don't support placeholders)." 
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "Cch3Ld5Ds4i2" }, "outputs": [], "source": [ "def resnet_layer(inputs,\n", " num_filters=16,\n", " kernel_size=3,\n", " strides=1,\n", " activation='relu',\n", " batch_normalization=True,\n", " conv_first=True):\n", " \"\"\"2D Convolution-Batch Normalization-Activation stack builder.\n", "\n", " Based on https://keras.io/examples/cifar10_resnet/.\n", "\n", " Args:\n", " inputs (tensor): input tensor from input image or previous layer\n", " num_filters (int): Conv2D number of filters\n", " kernel_size (int): Conv2D square kernel dimensions\n", " strides (int): Conv2D square stride dimensions\n", " activation (string): activation name\n", " batch_normalization (bool): whether to include batch normalization\n", " conv_first (bool): conv-bn-activation (True) or bn-activation-conv (False)\n", "\n", " Returns:\n", " x (tensor): tensor as input to the next layer\n", " \"\"\"\n", " conv = layers.Conv2D(num_filters,\n", " kernel_size=kernel_size,\n", " strides=strides,\n", " padding='same',\n", " kernel_initializer='he_normal',\n", " kernel_regularizer=tf.keras.regularizers.l2(1e-4))\n", "\n", " x = inputs\n", " if conv_first:\n", " x = conv(x)\n", " if batch_normalization:\n", " x = layers.BatchNormalization()(x)\n", " if activation is not None:\n", " x = layers.Activation(activation)(x)\n", " else:\n", " if batch_normalization:\n", " x = layers.BatchNormalization()(x)\n", " if activation is not None:\n", " x = layers.Activation(activation)(x)\n", " x = conv(x)\n", " return x\n", "\n", "\n", "def resnet_v2(input_tensor, depth, num_classes=10):\n", " \"\"\"ResNet Version 2 Model builder [b].\n", "\n", " Based on https://keras.io/examples/cifar10_resnet/.\n", "\n", " Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D or also known as\n", " bottleneck layer\n", " First shortcut connection per layer is 1 x 1 Conv2D.\n", " Second and onwards shortcut connection is identity.\n", " At the beginning of each stage, the feature map size is halved (downsampled)\n", " by a convolutional layer with strides=2, while the number of filter maps is\n", " doubled. 
Within each stage, the layers have the same number filters and the\n", " same filter map sizes.\n", " Features maps sizes:\n", " conv1 : 32x32, 16\n", " stage 0: 32x32, 64\n", " stage 1: 16x16, 128\n", " stage 2: 8x8, 256\n", "\n", " Args:\n", " input_shape (tuple/list): shape of input image tensor\n", " depth (int): number of core convolutional layers\n", " num_classes (int): number of classes (CIFAR10 has 10)\n", "\n", " Returns:\n", " model (Model): Keras model instance\n", " \"\"\"\n", " if (depth - 2) % 9 != 0:\n", " raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])')\n", " # Start model definition.\n", " num_filters_in = 16\n", " num_res_blocks = int((depth - 2) / 9)\n", "\n", " inputs = tf.keras.Input(tensor=input_tensor)\n", " # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths\n", " x = resnet_layer(inputs=inputs, num_filters=num_filters_in, conv_first=True)\n", "\n", " # Instantiate the stack of residual units\n", " for stage in range(3):\n", " for res_block in range(num_res_blocks):\n", " activation = 'relu'\n", " batch_normalization = True\n", " strides = 1\n", " if stage == 0:\n", " num_filters_out = num_filters_in * 4\n", " if res_block == 0: # first layer and first stage\n", " activation = None\n", " batch_normalization = False\n", " else:\n", " num_filters_out = num_filters_in * 2\n", " if res_block == 0: # first layer but not first stage\n", " strides = 2 # downsample\n", "\n", " # bottleneck residual unit\n", " y = resnet_layer(inputs=x,\n", " num_filters=num_filters_in,\n", " kernel_size=1,\n", " strides=strides,\n", " activation=activation,\n", " batch_normalization=batch_normalization,\n", " conv_first=False)\n", " y = resnet_layer(inputs=y, num_filters=num_filters_in, conv_first=False)\n", " y = resnet_layer(inputs=y,\n", " num_filters=num_filters_out,\n", " kernel_size=1,\n", " conv_first=False)\n", " if res_block == 0:\n", " # linear projection residual shortcut connection to match\n", " # changed dims\n", " x = resnet_layer(inputs=x,\n", " num_filters=num_filters_out,\n", " kernel_size=1,\n", " strides=strides,\n", " activation=None,\n", " batch_normalization=False)\n", " x = layers.Add()([x, y])\n", "\n", " num_filters_in = num_filters_out\n", "\n", " # Add classifier on top.\n", " # v2 has BN-ReLU before Pooling\n", " x = layers.BatchNormalization()(x)\n", " x = layers.Activation('relu')(x)\n", " x = layers.AveragePooling2D(pool_size=8)(x)\n", " y = layers.Flatten()(x)\n", " outputs = layers.Dense(num_classes,\n", " activation='softmax',\n", " kernel_initializer='he_normal')(y)\n", "\n", " # Instantiate model.\n", " model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", " return model" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "hpSh8fWKiWO7" }, "source": [ "## TPU Set Up" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "c_JSG9X7iaxu" }, "outputs": [], "source": [ "def get_tpu_address():\n", " if 'TPU_NAME' in os.environ and 'COLAB_TPU_ADDR' in os.environ: # public colab\n", " assert os.environ['COLAB_GPU'] == '0'\n", " TPU_ADDRESS = os.environ['TPU_NAME']\n", " from google.colab import auth\n", " auth.authenticate_user()\n", " print('Running on public colab https://colab.research.google.com')\n", " elif 'TPU_NAME' in os.environ and not 'COLAB_TPU_ADDR' in os.environ: # Cloud TPU\n", " TPU_ADDRESS = os.environ['TPU_NAME']\n", " print('Running on Cloud TPU')\n", " else:\n", " raise ValueError('Unknown environment')\n", " return TPU_ADDRESS" ] 
}, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "5LaK9He6Yw4B" }, "outputs": [], "source": [ "cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(\n", " tpu=get_tpu_address())\n", "tf.tpu.experimental.initialize_tpu_system(cluster_resolver)\n", "tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "dAUaN-i9tHMY" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "ClkFCVbDp0HK" }, "outputs": [], "source": [ "def get_optimizer(model, loss, global_step):\n", " decayed_learning_rate = tf.train.exponential_decay(init_learning_rate,\n", " global_step=global_step,\n", " decay_rate=lr_decay_rate,\n", " decay_steps=1)\n", " learning_rate = tf.maximum(decayed_learning_rate, final_learning_rate)\n", "\n", " if optimizer_name == 'kfac':\n", " decayed_damping = tf.train.exponential_decay(init_damping,\n", " global_step=global_step,\n", " decay_rate=damping_decay_rate,\n", " decay_steps=1)\n", " damping = tf.maximum(decayed_damping, final_damping)\n", " # We cannot use the Keras version because Keras optimizers do not support\n", " # a global_step argument for minimize. Instead, we use the Keras automated\n", " # layed collection functionality to get our layer collection.\n", " lc = kfac.keras.utils.get_layer_collection(\n", " model=model, loss=loss, seed=SEED)\n", " optimizer = kfac.PeriodicInvCovUpdateKfacOpt(\n", " learning_rate=learning_rate,\n", " damping=damping,\n", " momentum=momentum,\n", " layer_collection=lc,\n", " # Replica round robin places each inverse operations on a different \n", " # replica (TPU core) so that each inverse is computed on one replica\n", " # then the replicas are synced.\n", " placement_strategy='replica_round_robin')\n", "\n", " elif optimizer_name == 'adam':\n", " decayed_epsilon = tf.train.exponential_decay(init_epsilon,\n", " global_step=global_step,\n", " decay_rate=epsilon_decay_rate,\n", " decay_steps=1)\n", " epsilon = tf.maximum(decayed_epsilon, final_epsilon)\n", " optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,\n", " beta1=momentum,\n", " epsilon=epsilon)\n", " return optimizer" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "QuFSOAd-irxw" }, "outputs": [], "source": [ "def train_step_fn(info, loss_metric, accuracy_metric):\n", " # We create the model in the train step, but we want to return a reference to\n", " # it so it can be used for validation. We return a reference to the model list\n", " # which will be populated after the train_step is run.\n", " model_list = []\n", " def train_step(inputs):\n", " # Need this for layer collection to work correctly. Also, by setting this\n", " # to 1, batchnorm statistics are computed in this pass.\n", " tf.keras.backend.set_learning_phase(1)\n", "\n", " img, labels = inputs\n", "\n", " # The model needs to be created in the train step for KFAC's layer\n", " # collection. 
TPU Strategy autographs this function, so if the model is\n", " # constructed outside the train step, KFAC's layer collection will capture\n", " # the wrong input/output tensors.\n", " # Since TPUs do not support placeholders, we must construct our model\n", " # directly with the input tensor.\n", " model = resnet_v2(input_tensor=img,\n", " depth=20,\n", " num_classes=info['num_classes'])\n", " model_list.append(model)\n", "\n", " # Since we constructed our model with the input tensor, the model.output\n", " # is equivalent to model(img). In a non TPU custom training loop, you can\n", " # use model(img) instead.\n", " logits = model.output\n", " cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(\n", " labels=labels, logits=logits)\n", " regularization_loss = tf.reduce_sum(model.losses)\n", " cross_entropy_loss = tf.reduce_mean(cross_entropy)\n", " # When using Distribution Strategy with KFAC, you must NOT scale the loss.\n", " loss = regularization_loss + cross_entropy_loss\n", "\n", " update_loss = loss_metric.update_state(loss)\n", " update_accuracy = accuracy_metric.update_state(y_true=labels, y_pred=logits)\n", "\n", " global_step = tf.train.get_or_create_global_step()\n", "\n", " optimizer = get_optimizer(model=model,\n", " loss='sparse_categorical_crossentropy',\n", " global_step=global_step)\n", "\n", " train_op = optimizer.minimize(loss,\n", " var_list=model.trainable_weights,\n", " global_step=global_step)\n", "\n", " # Control dependencies ensures updates are run before the loss is returned\n", " with tf.control_dependencies([train_op, update_loss, update_accuracy]):\n", " return tf.identity(loss)\n", "\n", " return train_step, model_list" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "_FfEPQAvrTA1" }, "outputs": [], "source": [ "def eval_step_fn(model, loss_metric, accuracy_metric):\n", " \"\"\"For validation or test.\"\"\"\n", "\n", " def eval_step(inputs):\n", " tf.keras.backend.set_learning_phase(0)\n", "\n", " img, labels = inputs\n", " logits = model(img, training=False)\n", " cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(\n", " labels=labels, logits=logits)\n", " regularization_loss = tf.reduce_sum(model.losses)\n", " cross_entropy_loss = tf.reduce_mean(cross_entropy)\n", " loss = regularization_loss + cross_entropy_loss\n", "\n", " update_loss = loss_metric.update_state(loss)\n", " update_accuracy = accuracy_metric.update_state(y_true=labels, y_pred=logits)\n", "\n", " with tf.control_dependencies([update_loss, update_accuracy]):\n", " return tf.identity(loss)\n", "\n", " return eval_step" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "Hf5WFHYP8tT9" }, "outputs": [], "source": [ "tf.reset_default_graph()\n", "\n", "with tpu_strategy.scope():\n", " data, info = get_input_pipeline(batch_size=batch_size,\n", " seed=SEED,\n", " drop_remainder=True,\n", " repeat_validation=False)\n", "\n", " train_iterator = tpu_strategy.make_dataset_iterator(data['train'])\n", " val_iterator = tpu_strategy.make_dataset_iterator(data['validation'])\n", "\n", " train_loss_metric = tf.keras.metrics.Mean(\n", " 'training_loss', dtype=tf.float32)\n", " train_accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(\n", " 'training_accuracy', dtype=tf.float32)\n", " val_loss_metric = tf.keras.metrics.Mean(\n", " 'val_loss', dtype=tf.float32)\n", " val_accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(\n", " 'val_accuracy', 
dtype=tf.float32)\n", "\n", " train_step, model_list = train_step_fn(\n", " info, train_loss_metric, train_accuracy_metric)\n", " # experimental_local_results gives us a list of the loss values from each\n", " # replica. Since we're tracking loss via the Keras Metric, we don't need to\n", " # worry about reporting (or reducing) this value. If we were to record this\n", " # value, we should do a mean across replicas since each replica will return an\n", " # unscaled loss and each replica has the same batch size.\n", " train_step_op = tpu_strategy.experimental_local_results(\n", " tpu_strategy.experimental_run(train_step, train_iterator))\n", "\n", " model = model_list[0] # There will only be one model in the list.\n", " val_step = eval_step_fn(model, val_loss_metric, val_accuracy_metric)\n", " val_step_op = tpu_strategy.experimental_local_results(\n", " tpu_strategy.experimental_run(val_step, val_iterator))\n", "\n", " all_variables = (\n", " tf.global_variables() +\n", " train_loss_metric.variables + train_accuracy_metric.variables +\n", " val_loss_metric.variables + val_accuracy_metric.variables\n", " )" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "D2w08M2dtwyL" }, "outputs": [], "source": [ "# Without this config, TensorFlow will attempt to place two connected ops on\n", "# different devices, which will cause an InvalidArgumentError.\n", "config = tf.ConfigProto()\n", "config.allow_soft_placement = True\n", "cluster_spec = cluster_resolver.cluster_spec()\n", "if cluster_spec:\n", " config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())\n", "\n", "with tf.Session(cluster_resolver.master(), config=config) as session:\n", " session.run([v.initializer for v in all_variables])\n", " session.run(train_iterator.initializer)\n", " print('Starting training...')\n", " for step in range(num_training_steps):\n", " session.run(train_step_op)\n", "\n", " if step % steps_per_epoch == 0:\n", " session.run(val_iterator.initializer)\n", " for _ in range(val_steps):\n", " session.run(val_step_op)\n", "\n", " print('================ Step {} ================'.format(step))\n", " # The printed train loss is the mean over the entire epoch.\n", " print('Train Loss {}'.format(session.run(train_loss_metric.result())))\n", " print('Train Accuracy {}'.format(\n", " session.run(train_accuracy_metric.result())))\n", " print('Val Loss {}'.format(session.run(val_loss_metric.result())))\n", " print('Val Accuracy {}'.format(\n", " session.run(val_accuracy_metric.result())))\n", " train_loss_metric.reset_states()\n", " train_accuracy_metric.reset_states()\n", " val_loss_metric.reset_states()\n", " val_accuracy_metric.reset_states()\n", "\n", " print('Done training')" ] } ], "metadata": { "colab": { "collapsed_sections": [ "_DDaAex5Q7u-", "v3vSki-usp9k", "SLvlpsups2aR" ], "last_runtime": { "build_target": "", "kind": "local" }, "name": "KFAC vs Adam on CIFAR10 - TPU.ipynb", "provenance": [ { "file_id": "1GOgzfQLpg5aoq_uajqcqLqFTY0ohduEr", "timestamp": 1565229974969 }, { "file_id": "1pqtoYduODZyJKt4-kwVkt_KtNQCnaNDp", "timestamp": 1565044838251 } ], "version": "0.3.2" }, "kernelspec": { "display_name": "Python 2", "name": "python2" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: kfac/examples/mnist.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utilities for loading MNIST into TensorFlow.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import numpy as np import tensorflow.compat.v1 as tf __all__ = [ 'load_mnist_as_tensors', 'load_mnist_as_dataset', 'load_mnist_as_iterator', ] def load_mnist_as_tensors(flatten_images=True, dtype=tf.float32): """Loads MNIST as Tensors. Args: flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into [784]-shaped vectors. dtype: The TF dtype to return the images as. Returns: images, labels, num_examples """ # mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets( # '/tmp/mnist', reshape=flatten_images) # num_examples = len(mnist_data.train.labels) # images = mnist_data.train.images # labels = mnist_data.train.labels # # images = tf.constant(np.asarray(images, dtype=np.float32)) # labels = tf.constant(np.asarray(labels, dtype=np.int64)) # # return images, labels, num_examples (images, labels), _ = tf.keras.datasets.mnist.load_data() num_examples = images.shape[0] if flatten_images: images = images.reshape(images.shape[0], 28**2) else: images = images.reshape(images.shape[0], 28, 28, 1) images = images.astype('float64') labels = labels.astype('int32') images /= 255. images = tf.constant(images, dtype=dtype) labels = tf.constant(labels) return images, labels, num_examples def load_mnist_as_dataset(flatten_images=True): """Loads MNIST as a Dataset object. Args: flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into [784]-shaped vectors. Returns: dataset, num_examples, where dataset is a Dataset object containing the whole MNIST training dataset and num_examples is the number of examples in the MNIST dataset (should be 60000). """ images, labels, num_examples = load_mnist_as_tensors( flatten_images=flatten_images) dataset = tf.data.Dataset.from_tensor_slices((images, labels)) return dataset, num_examples def load_mnist_as_iterator(num_epochs, batch_size, use_fake_data=False, flatten_images=True): """Loads MNIST dataset as an iterator Tensor. Args: num_epochs: int. Number of passes to make over the dataset. batch_size: int. Number of examples per minibatch. use_fake_data: bool. If True, generate a synthetic dataset rather than reading MNIST in. flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into [784]-shaped vectors. Returns: examples: Tensor of shape [batch_size, 784] if 'flatten_images' is True, else [batch_size, 28, 28, 1]. Each row is one example. Values in [0, 1]. labels: Tensor of shape [batch_size]. Indices of integer corresponding to each example. Values in {0...9}. 
""" if use_fake_data: rng = np.random.RandomState(42) num_examples = batch_size * 4 images = rng.rand(num_examples, 28 * 28) if not flatten_images: images = np.reshape(images, [num_examples, 28, 28, 1]) labels = rng.randint(10, size=num_examples) dataset = tf.data.Dataset.from_tensor_slices((np.asarray( images, dtype=np.float32), np.asarray(labels, dtype=np.int64))) else: dataset, num_examples = load_mnist_as_dataset(flatten_images=flatten_images) dataset = (dataset.shuffle(num_examples).repeat(num_epochs) .batch(batch_size).prefetch(5)) return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() ================================================ FILE: kfac/examples/rnn_mnist.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """RNN trained to do sequential MNIST classification using K-FAC. This demonstrates the use of the RNN approximations from the paper "Kronecker-factored Curvature Approximations for Recurrent Neural Networks". The setup here is similar to the autoencoder example. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import math # Dependency imports from absl import flags import kfac import tensorflow.compat.v1 as tf from kfac.examples import mnist from kfac.python.ops.kfac_utils import data_reader from kfac.python.ops.kfac_utils import data_reader_alt # We need this for now since linear layers without biases don't work with # automatic scanning at the moment _INCLUDE_INPUT_BIAS = True flags.DEFINE_string('kfac_approx', 'kron_indep', 'The type of approximation to use for the recurrent ' 'layers. "kron_indep" is the one which assumes ' 'independence across time, "kron_series_1" is "Option 1" ' 'from the paper, and "kron_series_2" is "Option 2".') flags.DEFINE_integer('inverse_update_period', 5, '# of steps between computing inverse of Fisher factor ' 'matrices.') flags.DEFINE_integer('cov_update_period', 1, '# of steps between computing covaraiance matrices.') flags.DEFINE_integer('damping_adaptation_interval', 5, '# of steps between updating the damping parameter.') flags.DEFINE_float('learning_rate', 3e-4, 'Learning rate to use when adaptation="off".') flags.DEFINE_float('momentum', 0.9, 'Momentum decay value to use when ' 'lrmu_adaptation="off" or "only_lr".') flags.DEFINE_boolean('use_batch_size_schedule', True, 'If True then we use the growing mini-batch schedule from ' 'the original K-FAC paper.') flags.DEFINE_integer('batch_size', 1024, 'The size of the mini-batches to use if not using the ' 'schedule.') flags.DEFINE_string('lrmu_adaptation', 'on', 'If set to "on" then we use the quadratic model ' 'based learning-rate and momentum adaptation method from ' 'the original paper. Note that this only works well in ' 'practice when use_batch_size_schedule=True. 
Can also ' 'be set to "off" and "only_lr", which turns ' 'it off, or uses a version where the momentum parameter ' 'is fixed (resp.).') flags.DEFINE_boolean('use_alt_data_reader', True, 'If True we use the alternative data reader for MNIST ' 'that is faster for small datasets.') flags.DEFINE_integer('num_hidden', 128, 'Hidden state dimension of the RNN.') flags.DEFINE_boolean('use_auto_registration', False, 'Whether to use the automatic registration feature.') flags.DEFINE_string('device', '/gpu:0', 'The device to run the major ops on.') FLAGS = flags.FLAGS def make_train_op(batch_size, batch_loss, layer_collection, loss_fn, cached_reader): """Constructs optimizer and train op. Args: batch_size: Tensor of shape (), Size of the training batch. batch_loss: Tensor of shape (), Loss with respect to minibatch to be minimzed. layer_collection: LayerCollection or None. Registry for model parameters. Required when using a K-FAC optimizer. loss_fn: Function which takes as input training data and returns loss. cached_reader: `data_reader.CachedReader` instance. Returns: train_op: Op that can be used to update model parameters. optimizer: Optimizer used to produce train_op. Raises: ValueError: If layer_collection is None when K-FAC is selected as an optimization method. """ global_step = tf.train.get_or_create_global_step() if layer_collection is None: raise ValueError('layer_collection must be defined to use K-FAC.') if FLAGS.lrmu_adaptation == 'on': learning_rate = None momentum = None momentum_type = 'qmodel' elif FLAGS.lrmu_adaptation == 'only_lr': learning_rate = None momentum = FLAGS.momentum momentum_type = 'qmodel_fixedmu' elif FLAGS.lrmu_adaptation == 'off': learning_rate = FLAGS.learning_rate momentum = FLAGS.momentum # momentum_type = 'regular' momentum_type = 'adam' optimizer = kfac.PeriodicInvCovUpdateKfacOpt( invert_every=FLAGS.inverse_update_period, cov_update_every=FLAGS.cov_update_period, learning_rate=learning_rate, damping=150., # When using damping adaptation it is advisable to start # with a high value. This value is probably far too high # to use for most neural nets if you aren't using damping # adaptation. (Although it always depends on the scale of # the loss.) cov_ema_decay=0.95, momentum=momentum, momentum_type=momentum_type, layer_collection=layer_collection, batch_size=batch_size, num_burnin_steps=5, adapt_damping=True, is_chief=True, prev_train_batch=cached_reader.cached_batch, loss=batch_loss, loss_fn=loss_fn, damping_adaptation_decay=0.95, damping_adaptation_interval=FLAGS.damping_adaptation_interval, min_damping=1e-5 ) return optimizer.minimize(batch_loss, global_step=global_step), optimizer def eval_model(x, num_classes, layer_collection=None): """Evaluate the model given the data and possibly register it.""" num_hidden = FLAGS.num_hidden num_timesteps = x.shape[1] num_input = x.shape[2] # Strip off the annoying last dimension of size 1 (added for convenient use # with conv nets). x = x[..., 0] # Unstack to get a list of 'num_timesteps' tensors of # shape (batch_size, num_input) x_unstack = tf.unstack(x, num_timesteps, 1) # We need to do this manually without cells since we need to get access # to the pre-activations (i.e. the output of the "linear layers"). 
w_in = tf.get_variable('w_in', shape=[num_input, num_hidden]) if _INCLUDE_INPUT_BIAS: b_in = tf.get_variable('b_in', shape=[num_hidden]) w_rec = tf.get_variable('w_rec', shape=[num_hidden, num_hidden]) b_rec = tf.get_variable('b_rec', shape=[num_hidden]) a = tf.zeros([tf.shape(x_unstack[0])[0], num_hidden], dtype=tf.float32) # Here 'a' are the activations, 's' the pre-activations a_list = [] s_in_list = [] s_rec_list = [] s_list = [] for input_ in x_unstack: a_list.append(a) s_in = tf.matmul(input_, w_in) if _INCLUDE_INPUT_BIAS: s_in += b_in s_rec = tf.matmul(a, w_rec) + b_rec # s_rec = b_rec + tf.matmul(a, w_rec) # this breaks the graph scanner s = s_in + s_rec s_in_list.append(s_in) s_rec_list.append(s_rec) s_list.append(s) a = tf.tanh(s) final_rnn_output = a # NOTE: we can uncomment the lines below without changing how the algorithm # behaves. This is because the derivative of the loss w.r.t. to s is the # the same as it is for both s_in and s_rec. This can be seen easily from # the chain rule. # # s_rec_list = s_list # s_in_list = s_list if _INCLUDE_INPUT_BIAS: pin = (w_in, b_in) else: pin = w_in if layer_collection: layer_collection.register_fully_connected_multi(pin, x_unstack, s_in_list, approx=FLAGS.kfac_approx) layer_collection.register_fully_connected_multi((w_rec, b_rec), a_list, s_rec_list, approx=FLAGS.kfac_approx) # Output parameters (need this no matter how we construct the RNN): w_out = tf.get_variable('w_out', shape=[num_hidden, num_classes]) b_out = tf.get_variable('b_out', shape=[num_classes]) logits = tf.matmul(final_rnn_output, w_out) + b_out if layer_collection: layer_collection.register_fully_connected((w_out, b_out), final_rnn_output, logits) return logits def compute_loss(inputs, labels, num_classes, layer_collection=None): """Compute loss value.""" with tf.variable_scope('model', reuse=tf.AUTO_REUSE): if FLAGS.use_auto_registration: logits = eval_model(inputs, num_classes) else: logits = eval_model(inputs, num_classes, layer_collection=layer_collection) losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) loss = tf.reduce_mean(losses) if layer_collection is not None: layer_collection.register_softmax_cross_entropy_loss(logits) if FLAGS.use_auto_registration: layer_collection.auto_register_layers() return loss def load_mnist(): """Creates MNIST dataset and wraps it inside cached data reader. Returns: cached_reader: `data_reader.CachedReader` instance which wraps MNIST dataset. num_examples: int. The number of training examples. """ # Wrap the data set into cached_reader which provides variable sized training # and caches the read train batch. if not FLAGS.use_alt_data_reader: # Version 1 using data_reader.py (slow!) dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=False) if FLAGS.use_batch_size_schedule: max_batch_size = num_examples else: max_batch_size = FLAGS.batch_size # Shuffle before repeat is correct unless you want repeat cases in the # same batch. 
dataset = (dataset.shuffle(num_examples).repeat() .batch(max_batch_size).prefetch(5)) dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() # This version of CachedDataReader requires the dataset to be shuffled return data_reader.CachedDataReader(dataset, max_batch_size), num_examples else: # Version 2 using data_reader_alt.py (faster) images, labels, num_examples = mnist.load_mnist_as_tensors( flatten_images=False) dataset = (images, labels) # This version of CachedDataReader requires the dataset to NOT be shuffled return data_reader_alt.CachedDataReader(dataset, num_examples), num_examples def main(_): # Load dataset. cached_reader, num_examples = load_mnist() num_classes = 10 minibatch_maxsize_targetiter = 500 minibatch_maxsize = num_examples minibatch_startsize = 1000 div = (float(minibatch_maxsize_targetiter-1) / math.log(float(minibatch_maxsize)/minibatch_startsize, 2)) batch_size_schedule = [ min(int(2.**(float(k)/div) * minibatch_startsize), minibatch_maxsize) for k in range(500) ] batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size') layer_collection = kfac.LayerCollection() def loss_fn(minibatch, layer_collection=None): return compute_loss(minibatch[0], minibatch[1], num_classes, layer_collection=layer_collection) minibatch = cached_reader(batch_size) batch_loss = loss_fn(minibatch, layer_collection=layer_collection) # Make training op with tf.device(FLAGS.device): train_op, opt = make_train_op( batch_size, batch_loss, layer_collection, loss_fn=loss_fn, cached_reader=cached_reader) learning_rate = opt.learning_rate momentum = opt.momentum damping = opt.damping rho = opt.rho qmodel_change = opt.qmodel_change global_step = tf.train.get_or_create_global_step() # Without setting allow_soft_placement=True there will be problems when # the optimizer tries to place certain ops like "mod" on the GPU (which isn't # supported). config = tf.ConfigProto(allow_soft_placement=True) # Train model. with tf.train.MonitoredTrainingSession(save_checkpoint_secs=30, config=config) as sess: while not sess.should_stop(): i = sess.run(global_step) if FLAGS.use_batch_size_schedule: batch_size_ = batch_size_schedule[min(i, len(batch_size_schedule) - 1)] else: batch_size_ = FLAGS.batch_size _, batch_loss_ = sess.run([train_op, batch_loss], feed_dict={batch_size: batch_size_}) # We get these things in a separate sess.run() call because they are # stored as variables in the optimizer. (So there is no computational cost # to getting them, and if we don't get them after the previous call is # over they might not be updated.) (learning_rate_, momentum_, damping_, rho_, qmodel_change_) = sess.run([learning_rate, momentum, damping, rho, qmodel_change]) # Print training stats. tf.logging.info( 'iteration: %d', i) tf.logging.info( 'mini-batch size: %d | mini-batch loss = %f', batch_size_, batch_loss_) tf.logging.info( 'learning_rate = %f | momentum = %f', learning_rate_, momentum_) tf.logging.info( 'damping = %f | rho = %f | qmodel_change = %f', damping_, rho_, qmodel_change_) tf.logging.info('----') if __name__ == '__main__': tf.disable_v2_behavior() tf.app.run(main) ================================================ FILE: kfac/python/__init__.py ================================================ ================================================ FILE: kfac/python/keras/README.md ================================================ # K-FAC for Keras **K-FAC for Keras** is an implementation of K-FAC, an approximate second-order optimization method, in TensorFlow. 
You can read more about it in the paper [here][paper] and the GitHub docs [here][index]. [index]: https://github.com/tensorflow/kfac/tree/master/docs/index.md [paper]: https://arxiv.org/abs/1503.05671 ## Why should I use K-FAC for Keras? In addition to the reasons outlined on the GitHub docs, the Keras version handles layer and loss registration automatically and works with Keras's convenient training API. See the reference code [here][cifar10]. [cifar10]: https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10.ipynb [cifar10tpu]: https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb ## How do I use K-FAC for Keras? Using this optimizer is almost the same as using any other Keras optimizer, except you must also pass the loss and model to the optimizer. The optimizer will automatically register the model layers and loss so K-FAC can compute the fisher approximations. ```python import tensorflow.compat.v1 as tf import kfac # Build Keras Model (can use functional or sequential) model = tf.keras.Model(...) loss = 'sparse_categorical_crossentropy' # or a tf.keras.losses.* instance # Construct Optimizer optimizer = kfac.keras.optimizers.Kfac(learning_rate=0.001, damping=0.01, model=model, loss=loss) # Compile and Fit Model model.compile(optimizer=optimizer, loss=loss, ...) model.fit(...) ``` Check out our CIFAR-10 CNN training [example][cifar10] and [TPU Strategy example][cifar10tpu] for more details. This optimizer currently supports the following tf.keras.layers types: Conv2D, Conv1D, Dense, BatchNormalization, LayerNormalization and Embedding. The following tf.keras.losses are supported: sparse_categorical_crossentropy, categorical_crossentropy, binary_crossentropy, and mean_squared_error. You may use any architecture with these basic layers and losses, including multiple branches and loss functions. To use an unsupported layer or loss, you can register layers manually using a LayerCollection object and pass that to the optimizer constructor. Examples of using LayerCollection are [here][layercollection]. [layercollection]: https://github.com/tensorflow/kfac/tree/master/kfac/examples ## How is K-FAC Different from Other Keras Optimizers? 1. When using your model as a callable (i.e. `output = model(input)`), `input` must be a Keras layer. If it is a normal tensor, you can wrap it as follows: `new_input = tf.keras.layers.Input(tensor=input)`. This is so Keras registers the layer as an inbound_node during the call, allowing our layer collection to register it correctly. By default, our automatic layer collection will register only the latest use of the model. 2. Only a subset of the hyperparameters can be accessed and modified after instantiation. These are: learning_rate, damping, momentum, weight_decay_coeff, norm_constraint, and batch_size. These hyperparameters will work the same as normal hyperparameters in native Keras optimizers and can be used with tools like hyperparameter scheduler callbacks. You can see exactly which hyperparameters are modifiable by checking the `optimizer.mutable_hyperparameters` property. Note that damping cannot be modified when using adaptive damping, and momentum/learning_rate cannot be modified when using qmodel momentum. Also, if any of the hyperparameters are `None` during instantiation, they will not be modifiable during training. 3. This optimizer is tested with TPUStrategy and MirroredStrategy. However, you may not use a Strategy with model.fit for two reasons. 
First, we expect an unscaled loss (i.e. it should NOT be scaled by
1.0 / global_batch_size). Second, TPUStrategy will autograph the train step, so
your model and optimizer must both be created in the train step for KFAC to
work. This is not possible with model.fit. See our [CIFAR10 TPU][cifar10tpu]
example for details on how to do this.

4. This optimizer is fully compatible with tf.keras.models.save_model or
model.save(). To load the compiled model with the optimizer, you must use our
saving_utils.load_model method, which is identical to
tf.keras.models.load_model except it registers the model with the optimizer
after compiling the model and before loading the optimizer's weights. Example:

```python
import tensorflow as tf
import kfac

model = tf.keras.Model(...)
loss = tf.keras.losses.MeanSquaredError()  # could be a serialized loss function
optimizer = kfac.keras.optimizers.Kfac(learning_rate=0.001, damping=0.01,
                                       model=model, loss=loss)
model.compile(optimizer, loss)
model.fit(...)
model.save('saved_model.hdf5')  # or tf.keras.models.save_model(model)
...
loaded_model = kfac.keras.saving_utils.load_model('saved_model.hdf5')
loaded_model.fit(...)
```

## EXPERIMENTAL - How can I use the adaptive damping/momentum/learning rate?

The original [KFAC paper][paper] outlines how the optimizer can automatically
adjust the learning rate, momentum, and damping. You can use it as follows:

```python
import tensorflow.compat.v1 as tf
import kfac

# tf.data.Dataset dataset
dataset = ...
dataset = dataset.shuffle(...).repeat().batch(..., drop_remainder=True)
train_batch = tf.data.make_one_shot_iterator(dataset).get_next()  # (x, y) tensors

model = tf.keras.Model(...)
loss = 'sparse_categorical_crossentropy'

# Construct Optimizer
optimizer = kfac.keras.optimizers.Kfac(damping=10.0, adaptive=True,
                                       model=model, loss=loss,
                                       train_batch=train_batch, ...)

# Compile and Fit Model
model.compile(optimizer=optimizer, loss=loss, ...)
model.fit(train_batch, ...)
```

If your batch size is not fixed at the start of training (i.e. it has an
unknown (`None`) dimension, such as when `drop_remainder=False`), you must pass
the `batch_size` in the constructor. If you do not use `optimizer.minimize(...)`,
you must pass in the `loss_tensor`. If you use a custom loss function, you must
pass in the `loss_fn` in the constructor. Look at the documentation for the
TensorFlow KFAC optimizer for details on how to customize this more.

Note that this feature is experimental, so it is not recommended for standard
use cases. It works best when used with a high initial damping (10.0-100.0),
and with a large batch size. The [autoencoder example][ae_eg] shows using the
adaptive damping and qmodel momentum successfully.

[ae_eg]: https://github.com/tensorflow/kfac/blob/master/kfac/examples/autoencoder_mnist.py



================================================
FILE: kfac/python/keras/__init__.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================== """KFAC for Keras.""" from kfac.python.keras import callbacks from kfac.python.keras import optimizers from kfac.python.keras import saving_utils from kfac.python.keras import utils ================================================ FILE: kfac/python/keras/callbacks.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Hyperparameter Scheduling Callbacks for Keras K-FAC.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import abc import six import tensorflow.compat.v1 as tf @six.add_metaclass(abc.ABCMeta) class HyperparameterDecay(tf.keras.callbacks.Callback): """Base class for global_step/iterations-based optimizer decay callbacks.""" def __init__(self, hyperparameter, num_delay_steps=0, verbose=0): """Construct a new HyperparameterDecay. Args: hyperparameter: String specifying the optimizer attribute to decay. num_delay_steps: Integer specifying how many steps to wait before decaying the attribute. verbose: Integer. When > 1, the hyperparameter value is printed every epoch. """ self._hyperparameter = hyperparameter self._num_delay_steps = num_delay_steps self.verbose = verbose def on_train_begin(self, logs=None): self._optimizer = self.model.optimizer if not hasattr(self._optimizer, self._hyperparameter): raise ValueError('Optimizer must have a "{}" attribute.' .format(self._hyperparameter)) if not hasattr(self._optimizer, 'iterations'): raise ValueError('Optimizer must have a "iterations" attribute.') def on_epoch_begin(self, epoch, logs=None): if self.verbose > 0: value = float(tf.keras.backend.get_value(getattr(self._optimizer, self._hyperparameter))) print('\nEpoch {:05}: Current {} is {}.' .format(epoch + 1, self._hyperparameter, value)) def on_epoch_end(self, epoch, logs=None): if logs is not None: logs[self._hyperparameter] = tf.keras.backend.get_value( getattr(self._optimizer, self._hyperparameter)) def _get_global_step(self): return (tf.keras.backend.get_value(self._optimizer.iterations) - self._num_delay_steps) class PolynomialDecay(HyperparameterDecay): """Polynomial Optimizer Hyperparameter Schedule. Based on https://www.tensorflow.org/api_docs/python/tf/train/polynomial_decay The decay applies as follows for num_decay_steps steps when the global_step (i.e. optimizer.iterations) exceeds the num_delay_steps. step = global_step - num_delay_steps decayed_value = (init_value - final_value) * (1 - step / num_decay_steps) ^ (power) + final_value """ def __init__(self, hyperparameter, init_value, final_value, power, num_decay_steps, **kwargs): """Construct a new PolynomialDecay Callback. Args: hyperparameter: String specifying the optimizer attribute to decay. init_value: Float specifying initial value of the attribute. 
final_value: Float specifying value of attribute at the end of the decay. power: Float specifying power (exponent) of the polynomial decay. num_decay_steps: Integer, number of steps to decay the attribute. **kwargs: Keyword arguments for HyperparameterDecay. This includes num_delay_steps and verbose. """ super(PolynomialDecay, self).__init__(hyperparameter, **kwargs) self._init_value = init_value self._final_value = final_value self._power = power self._num_decay_steps = num_decay_steps def on_batch_begin(self, batch, logs=None): step = self._get_global_step() if step > 0 and step <= self._num_decay_steps: decayed_value = ((self._init_value - self._final_value) * (1 - step / self._num_decay_steps) ** (self._power) + self._final_value) setattr(self._optimizer, self._hyperparameter, decayed_value) class ExponentialDecay(HyperparameterDecay): """Exponential Optimizer Hyperparameter Decay Schedule. The decay applies as follows for num_decay_steps steps when the global_step (i.e. optimizer.iterations) exceeds the num_delay_steps. If num_decay_steps is not provided, it will keep decaying for the duration of training. When a decay rate and num_decay_steps is provided: step = min(global_step - num_delay_steps, num_decay_steps) decayed_value = init_value * decay_rate^step When a decay_rate and final_value are provided: step = global_step - num_delay_steps decayed_value = max(init_value * decay_rate^step, final_value) When a final value and num_decay_steps is provided: step = global_step - num_delay_steps decayed_value = init_value * (final_value / init_value) ^ (step / num_decay_steps) """ def __init__(self, hyperparameter, init_value, final_value=None, decay_rate=None, num_decay_steps=None, **kwargs): """Construct a new ExponentialDecay Callback. You must specify exactly two of final_value, decay_rate, and num_decay_steps. Args: hyperparameter: String specifying the optimizer attribute to decay. init_value: Float specifying initial value of the attribute. final_value: Float specifying value of attribute at the end of the decay. decay_rate: Float specifying the decay rate of the decay. num_decay_steps: Integer, number of steps to decay the attribute. **kwargs: Keyword arguments for HyperparameterDecay. This includes num_delay_steps and verbose. """ super(ExponentialDecay, self).__init__(hyperparameter, **kwargs) self._num_decay_steps = num_decay_steps # In theory, we could support more different combinations of final_value, # num_decay_steps, and decay_rate, but for the sake of clarity we will limit # this callback to the below combinations. 
if final_value and decay_rate and num_decay_steps: raise ValueError('You must specify exactly two of final_value, decay_rate' ', and num_decay_steps.') if final_value and decay_rate: self._decay_func = lambda step: max( # pylint: disable=g-long-lambda (init_value * (decay_rate ** step)), final_value) elif decay_rate and num_decay_steps: self._decay_func = lambda step: (init_value * decay_rate ** step) elif final_value and num_decay_steps: self._decay_func = lambda step: ( # pylint: disable=g-long-lambda init_value * (final_value / init_value) ** (float(step) / num_decay_steps)) else: raise ValueError('You must specify exactly two of final_value, decay_rate' ', and num_decay_steps.') def on_batch_begin(self, batch, logs=None): global_step = self._get_global_step() if (global_step > 0 and (not self._num_decay_steps or global_step <= self._num_decay_steps)): decayed_value = self._decay_func(global_step) setattr(self._optimizer, self._hyperparameter, decayed_value) ================================================ FILE: kfac/python/keras/optimizers.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """KFAC Optimizer for Keras.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import inspect import numbers import re from absl import logging from tensorflow.python.keras import backend import six import tensorflow.compat.v1 as tf from kfac.python.keras import utils from kfac.python.ops import optimizer from kfac.python.ops.kfac_utils import periodic_inv_cov_update_kfac_opt # TODO(b/135110195): Support letting the user choose the TF KFAC optimizer. _KFAC_OPT_CLASS = periodic_inv_cov_update_kfac_opt.PeriodicInvCovUpdateKfacOpt # TODO(b/134945404): Change how default config args are retrieved. getfullargspec = inspect.getfullargspec if six.PY3 else inspect.getargspec _KFAC_ARGS = getfullargspec(optimizer.KfacOptimizer.__init__) _PERIODIC_KFAC_ARGS = getfullargspec(_KFAC_OPT_CLASS.__init__) _DEFAULT_KWARGS = dict(zip(reversed(_KFAC_ARGS.args), reversed(_KFAC_ARGS.defaults))) _DEFAULT_KWARGS.update(zip(reversed(_PERIODIC_KFAC_ARGS.args), reversed(_PERIODIC_KFAC_ARGS.defaults))) _MUTABLE_HYPER_PARAMS = {'learning_rate', 'momentum', 'damping', 'weight_decay_coeff', 'norm_constraint'} def _configure_kfac_kwargs_for_adaptive(kfac_kwargs, adaptive): """Checks and fills in some required kwargs to use an adaptive mode. This will set up kfac_kwargs for adaptive, adapt_damping, and/or qmodel momentum, if needed. It will not check for train_batch or batch_size, as that check happens right before the minimize. It will set the following if not set by the user: If adaptive=True: - adapt_damping=True, momentum=None, momentum_type='qmodel' - The checks listed below. If adapt_damping=True: - use_passed_loss=True, and then it will get the loss_tensor from minimize. 
- update_damping_immediately=True - damping_adaptation_interval=5 if the user hasn't set this already. - invert_every=5 if the user hasn't set this already. If momentum_type='qmodel' or momentum_type='qmodel_fixedmu': - Ensures learning rate and momentum are None. Args: kfac_kwargs: dict of keyword arguments to be passed to PeriodicInvCovUpdateKfacOpt. adaptive: bool indicating the optimizer is in adaptive mode. """ if adaptive: kfac_kwargs.update({ 'adapt_damping': True, 'momentum': None, 'momentum_type': 'qmodel', }) if kfac_kwargs.get('momentum_type', 'regular').lower().startswith('qmodel'): if kfac_kwargs['learning_rate']: raise ValueError('learning_rate must be None to use adaptive/qmodel.') if kfac_kwargs.get('momentum', None): raise ValueError('momentum must be None to use adaptive/qmodel.') if kfac_kwargs.get('adapt_damping', False): defaults = {'use_passed_loss': True, 'update_damping_immediately': True} # This way, we keep the user's preferences and only replace missing items. defaults.update(kfac_kwargs) kfac_kwargs.update(defaults) if not ('invert_every' in kfac_kwargs and 'damping_adaptation_interval' in kfac_kwargs): # damping_adaptation_interval % invert_every must = 0 kfac_kwargs['invert_every'] = 5 kfac_kwargs['damping_adaptation_interval'] = 5 class Kfac(tf.keras.optimizers.Optimizer): """The KFAC Optimizer for Keras.""" def __init__(self, # pylint: disable=invalid-name _sentinel=None, learning_rate=None, damping=None, model=None, loss=None, loss_weights=None, fisher_approx=None, layer_collection=None, adaptive=False, train_batch=None, name=None, seed=None, **kfac_kwargs): """Construct a new KFAC optimizer. If you construct this Optimizer without a model with a loss, model and loss, or a layer_collection, you must call register_layers before using the optimizer. If you use adaptive, adapt_damping, or qmodel_momentum, this class will set up the required loss functions and tensors. You must pass the train_batch tensors as a tuple (x, y). If the batch_size cannot be inferred from the train_batch[0] tensor, you pass in the batch_size in the constructor. You may not use numpy arrays as input when using the adaptive mode. If you do not use minimize, you must also provide the loss_tensor. When using Distribution Strategy, K-FAC expects a loss tensor that is normalized only by the per-replica batch size, and not the total batch size, unlike what is commonly recommended. This means you cannot use K-FAC with a Distribution Strategy and model.fit at the same time, since model.fit does this scaling for you. Instead, use a custom training loop with Distribution Strategy (there are examples in the Github repo). Args: _sentinel: Used to prevent positional parameters. Internal, do not use. learning_rate: float or 0D Tensor. Required if not using adapt_damping. Refer to kfac.KfacOptimizer for a detailed description. damping: Required. float or 0D Tensor. Refer to kfac.KfacOptimizer for a detailed description. model: Keras model which this class will optimize. Currently, dense, Conv 1D/2D, and embedding are supported as trainable layers. loss: Keras (normal or serialized) loss function. Could be a list or a dictionary mapping layer names to (normal or serialized) loss functions. Currently, sparse/normal categorical/binary cross entropy and MSE are supported. loss_weights: An optional list of coefficients or a dictionary mapping layer names to the coefficient for each loss functions. If it is a list, there must be a the same number of coefficients as loss functions. 
If it is a dictionary and a coefficient is not given for a loss function, a coefficient of 1.0 will be used. fisher_approx: An optional list of approximations or a dictionary mapping layer name/class to fisher approximation type. If it is a list, there must be the same number of approximations as there are layers with trainable parameters. For each layer, the approximation is determined as follows. If fisher_approx is a dictionary, first we check if the name is in the dict, if it isn't found the layer class is checked, if it isn't found the default is used. When fisher_approx is a list, the order of the approximations must match the order of the layers with trainable parameters given by model.layers. None is a valid dict/list entry and indicates to use the default approximation for that layer. layer_collection: Only use this argument when you have an unsupported model architecture and so manually register the layers. Refer to kfac.KfacOptimizer for a detailed description. adaptive: Whether this optimizer is in adaptive mode or not. In adaptive mode, we set momentum_type='qmodel' and adapt_damping=True, so you must provide the damping (used as the initial value). learning_rate and momentum must be None. You must provide a train_batch and potentially a batch_size if we cannot infer the batch_size from the train_batch. train_batch: A tuple (input, label). The input must be a tensor or a list of tensors that you can call the model on. The label must be a tensor or list of tensors compatible with the loss_fn. See utils.get_loss_fn for the standard loss_fn we create, or you can provide a custom loss_fn. name: Optional name for operations created when applying gradients. Defaults to "kfac". seed: Optional integer specifying the TensorFlow random seed. To get deterministic behaviour, the seed needs to be set because the targets are sampled to approximate the fisher. **kfac_kwargs: Additional arguments to be passed to kfac.PeriodicInvCovUpdateKfacOpt (and then to kfac.KfacOptimizer). Note the "loss" argument for kfac.KfacOptimizer should be passed as "loss_tensor". Raises: ValueError: If clipvalue or clipnorm arguments are used. ValueError: If positional arguments are used (or _sentinel is used). ValueError: If damping is not provided. ValueError: If learning_rate or momentum are set with adaptive=True. """ if tf.executing_eagerly(): logging.warn('Eager mode appears to be enabled. Kfac is untested in ' 'eager mode.') if _sentinel: raise ValueError('Do not pass positional arguments, only use keyword ' 'arguments.') if damping is None: raise ValueError('Please provide a value for damping.') if 'clipvalue' in kfac_kwargs: raise ValueError('Argument "clipvalue" is not support.') if 'clipnorm' in kfac_kwargs: raise ValueError('Argument "clipnorm" is not supported. Use ' '"norm_constraint" instead.') super(Kfac, self).__init__(name=name) kfac_kwargs.update({'name': self._name, 'learning_rate': learning_rate, 'damping': damping}) _configure_kfac_kwargs_for_adaptive(kfac_kwargs, adaptive) self._optimizer = None self._layer_collection = None self._model = model self._loss = loss self._have_tracked_vars = False self._tf_var_scope = self._name + '/tf_vars' # We use _kfac_kwargs and _config in various parts in the code below. # _kfac_kwargs is checked when we want to know only what the user passed. # _config is used when we want user selections with the default kwargs as a # fallback. 
self._kfac_kwargs = kfac_kwargs self._layer_collection_kwargs = { 'loss_weights': loss_weights, 'fisher_approx': utils.serialize_fisher_approx(fisher_approx), 'seed': seed, } self._config = _DEFAULT_KWARGS.copy() self._config.update(kfac_kwargs) self._config.update(self._layer_collection_kwargs) self._config['loss'] = utils.serialize_loss(loss) if 'loss_tensor' in self._kfac_kwargs: self._kfac_kwargs['loss'] = self._kfac_kwargs.pop('loss_tensor') self._mutable_hypers = _MUTABLE_HYPER_PARAMS.copy() if self._config['adapt_damping']: self._mutable_hypers.remove('damping') if self._config['momentum_type'].lower().startswith('qmodel'): self._mutable_hypers -= {'learning_rate', 'momentum'} for hp in self._mutable_hypers.copy(): if self._config[hp] is None: self._mutable_hypers.remove(hp) else: self._set_hyper(hp, self._config[hp]) if layer_collection: self.register_layers(layer_collection=layer_collection) if train_batch and self._kfac_kwargs.get('adapt_damping', False): self.register_train_batch(train_batch=train_batch) @property def name(self): # This settable property exists to avoid variable name scope conflicts. return self._name @name.setter def name(self, value): if self._optimizer: raise ValueError('Can\'t change the optimizer\'s name after the variables' ' are created') self._name = value self._config['name'] = value self._kfac_kwargs['name'] = value self._tf_var_scope = value + '/tf_vars' @property def optimizer(self): # We defer the creation of the optimizer for a few reasons. First, if the # user decides to use the model as a callable, we want to capture the latest # inbound node of the model. Also, this mimics the behaviour of existing # Keras optimizers, as all the variables are created on the first # apply_gradients call (unless the user tries to access this property). # Second, this reduces code duplication as we can use the super class's # _set_hypers and _create_hypers methods. Finally, if the user restores an # optimizer, this allows them to control the variable scope before the # variables are created (to avoid scope conflicts). if not self._optimizer: self._create_optimizer() return self._optimizer @property def layers(self): return self._layer_collection @property def mutable_hyperparameters(self): return self._mutable_hypers def register_layers(self, model=None, loss=None, layer_collection=None): if not layer_collection: if not loss and hasattr(model, 'loss'): loss = model.loss if not (model and loss): raise ValueError('Please provide a model with a loss, a model and loss,' ' or a LayerCollection') layer_collection = utils.get_layer_collection( model, loss, **self._layer_collection_kwargs) self._layer_collection = layer_collection self._kfac_kwargs['var_list'] = layer_collection.registered_variables def register_train_batch(self, train_batch, batch_size=None): """Configures the train_batch tuple and batch_size for adaptive damping.""" if not isinstance(train_batch, tuple): raise ValueError('You must provide the train_batch tuple of inputs to ' 'use adaptive/adapt_damping mode.') elif not all(isinstance(inp, tf.Tensor) for inp in train_batch): raise ValueError('You must use TF tensors as input.') self._kfac_kwargs['train_batch'] = train_batch if batch_size: self._kfac_kwargs['batch_size'] = batch_size elif 'batch_size' not in self._kfac_kwargs: inferred_batch_size = train_batch[0].shape.as_list()[0] if inferred_batch_size: self._kfac_kwargs['batch_size'] = inferred_batch_size else: raise ValueError('Could not infer batch_size from the train_batch. 
' 'Please provide it in the optimizer constructor or ' 'through register_train_batch.') def minimize(self, loss, var_list, grad_loss=None, name=None): if (self._config['use_passed_loss'] and 'loss' not in self._kfac_kwargs): self._kfac_kwargs['loss'] = loss return self._call_and_track_vars( 'minimize', loss, var_list=var_list, grad_loss=grad_loss, name=name) def apply_gradients(self, grads_and_vars, name=None): return self._call_and_track_vars( 'apply_gradients', grads_and_vars, name=name) def get_updates(self, loss, params): return [self.minimize(loss, params)] def get_config(self): config = self._config.copy() for param in self._hyper: config[param] = self._serialize_hyperparameter(param) return config def _create_optimizer(self): """Initializes the hyperparameters and sets the self._optimizer property.""" if self._optimizer: return if not self._layer_collection: self.register_layers(self._model, self._loss) if self._config['adapt_damping']: if 'train_batch' not in self._kfac_kwargs: raise ValueError('Must provide a train_batch tuple to use adaptive ' 'damping. Use register_train_batch or pass it in ' 'during optimizer construction.') if 'loss_fn' not in self._kfac_kwargs: self._kfac_kwargs['loss_fn'] = utils.get_loss_fn( self._model, self._loss, loss_weights=self._config['loss_weights']) with tf.name_scope(self._name): with tf.init_scope(): # "iterations" property will create iterations if necessary. _ = self.iterations self._create_hypers() self._kfac_kwargs.update(self._hyper) try: # We use the TF 1 variable_scope instead of the TF 2 recommended # name_scope because we need to recover the variables created in this # scope, which is not possible with name_scope. with tf.variable_scope(self._tf_var_scope): self._optimizer = _KFAC_OPT_CLASS( layer_collection=self._layer_collection, **self._kfac_kwargs) except ValueError as e: msg = str(e) if re.search('Variable .* already exists', msg): raise ValueError( 'You may have instantiated a KFAC Optimizer with the same name as ' 'an existing one. Try resetting the default graph, instantiating ' 'the optimizer with a different name, or changing the optimizer\'s ' 'name.\nHere is the original ValueError:\n ' + msg) elif re.search('Found the following errors with variable registration' '.*gamma.*registered with wrong number of uses.*', msg): # We don't regex the name batch_normalization because the user could # have renamed the layer. We don't regex beta because they could have # used BatchNorm without the shift. raise ValueError( 'There may have been an issue registering BatchNormalization. Try ' 'using tf.keras.backend.set_learning_phase before model ' 'construction. An alternative solution is to use the unfused ' 'batchnorm implementation (pass the argument fused=False to ' 'BatchNormalization).\nHere is the original ValueError:\n ' + msg) else: raise e def _call_and_track_vars(self, method_name, *args, **kwargs): # We call _create_optimizer outside of the var_scope because # _create_optimizer also opens the same variable_scope. self._create_optimizer() with tf.variable_scope(self._tf_var_scope): kwargs['global_step'] = self.iterations update_op = getattr(self._optimizer, method_name)(*args, **kwargs) if not self._have_tracked_vars: # We rely on the variables created in a deterministic order for get and # set weights. Sorting the variables by name is not a reliable way to # get a deterministic order due to the way TF KFAC assigns variable names. 
for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self._tf_var_scope): backend.track_variable(var) self.weights.append(var) self._have_tracked_vars = True return update_op def _set_hyper(self, name, value): """Set hyper `name` to value. value must be numeric.""" if self._hypers_created: if not isinstance(self._hyper[name], tf.Variable): raise AttributeError("Can't set attribute: {}".format(name)) if not isinstance(value, numbers.Number): raise ValueError('Dynamic reassignment only supports setting with a ' 'number. tf.Tensors and tf.Variables can only be used ' 'before the internal kfac optimizer is created.') backend.set_value(self._hyper[name], value) else: super(Kfac, self)._set_hyper(name, value) ================================================ FILE: kfac/python/keras/saving_utils.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Saving/loading utilities for models created with the KFAC Optimizer.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import json from absl import logging from tensorflow.python.keras.saving import hdf5_format import tensorflow.compat.v1 as tf from kfac.python.keras import optimizers # This optional h5py import allows users to import all of tensorflow_kfac # without h5py. The ImportError is raised manually if they try to use load_model # without h5py. This follows the Keras save.py style: # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/saving/save.py try: import h5py # pylint: disable=g-import-not-at-top except ImportError: h5py = None def _compile_args_from_training_config(training_config, custom_objects=None): """Return model.compile arguments from training config.""" if custom_objects is None: custom_objects = {} optimizer_config = training_config['optimizer_config'] optimizer = tf.keras.optimizers.deserialize( optimizer_config, custom_objects=custom_objects) # Recover loss functions and metrics. loss_config = training_config['loss'] # Deserialize loss class. if isinstance(loss_config, dict) and 'class_name' in loss_config: loss_config = tf.keras.losses.get(loss_config) loss = tf.nest.map_structure( lambda obj: custom_objects.get(obj, obj), loss_config) metrics = tf.nest.map_structure( lambda obj: custom_objects.get(obj, obj), training_config['metrics']) weighted_metrics = tf.nest.map_structure( lambda obj: custom_objects.get(obj, obj), training_config.get('weighted_metrics', None)) sample_weight_mode = training_config['sample_weight_mode'] loss_weights = training_config['loss_weights'] return dict(optimizer=optimizer, loss=loss, metrics=metrics, weighted_metrics=weighted_metrics, loss_weights=loss_weights, sample_weight_mode=sample_weight_mode) def load_model(filepath, custom_objects=None, optimizer_name=None): """Loads and compiles a Keras model saved as an HDF5 file. 
Same as tf.keras.models.load_model, except it will always compile the model and instantiate the Kfac optimizer correctly. If you do not want the model to be compiled, or the model was saved without the optimizer, use tf.keras.models.load_model instead. Example: ```python import tensorflow as tf import kfac model = tf.keras.Model(...) loss = tf.keras.losses.MeanSquaredError() # could be a serialized loss function optimizer = kfac.keras.optimizers.Kfac(0.001, 0.01, model=model, loss=loss) model.compile(optimizer, loss) model.fit(...) model.save('saved_model.hdf5') # or use tf.keras.models.save_model ... loaded_model = kfac.keras.saving_utils.load_model('saved_model.hdf5') loaded_model.fit(...) ``` Args: filepath: One of the following: - String, path to the saved model - `h5py.File` object from which to load the model custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization. Kfac will be added to this dictionary automatically. optimizer_name: Optional string that specifies what variable scope you want the KFAC variables to be created in. Useful if you have multiple KFAC optimizers on one graph. Raises: ImportError: If h5py was not imported. Returns: A compiled Keras model with the Kfac optimizer correctly initialized. """ if h5py is None: raise ImportError('`load_model` requires h5py.') if not custom_objects: custom_objects = {} custom_objects['Kfac'] = optimizers.Kfac should_open_file = not isinstance(filepath, h5py.File) model_file = h5py.File(filepath, mode='r') if should_open_file else filepath model = tf.keras.models.load_model( model_file, custom_objects=custom_objects, compile=False) # Code below is current as of 2019-06-20 and may break due to future changes. # github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/saving/hdf5_format.py try: training_config = model_file.attrs.get('training_config') if hasattr(training_config, 'decode'): training_config = training_config.decode('utf-8') if training_config is None: raise ValueError('No training configuration found in save file, meaning ' 'the model was not compiled. Please use ' 'tf.keras.models.load_model instead.') training_config = json.loads(training_config) model.compile(**_compile_args_from_training_config(training_config, custom_objects)) model.optimizer.register_layers(model) if optimizer_name: model.optimizer.name = optimizer_name if 'optimizer_weights' in model_file: # Build train function (to get weight updates). # Models that aren't graph networks must wait until they are called # with data to _make_train_function() and so can't load optimizer # weights. model._make_train_function() # pylint: disable=protected-access opt_weight_vals = hdf5_format.load_optimizer_weights_from_hdf5_group( model_file) try: model.optimizer.set_weights(opt_weight_vals) except ValueError: logging.warn('Error in loading the saved optimizer state. As a ' 'result, your model is starting with a freshly ' 'initialized optimizer.') finally: if should_open_file: model_file.close() return model ================================================ FILE: kfac/python/keras/utils.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utility Functions for using KFAC with Keras Objects.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import six import tensorflow.compat.v1 as tf from kfac.python.ops import layer_collection as kfac_layer_collection from kfac.python.ops.tensormatch import tensorflow_graph_util layers = tf.keras.layers losses = tf.keras.losses activations = tf.keras.activations K = tf.keras.backend # Added when serializing layer class names to prevent serialized class names # from clashing with user-defined layer names. _CLASS_NAME_PREFIX = 'kfac_class_' _KERAS_LOSS_TO_KFAC_REGISTER_FUNC = { 'sparsecategoricalcrossentropy': kfac_layer_collection.LayerCollection .register_softmax_cross_entropy_loss, 'categoricalcrossentropy': kfac_layer_collection.LayerCollection .register_softmax_cross_entropy_loss, 'binarycrossentropy': kfac_layer_collection.LayerCollection .register_sigmoid_cross_entropy_loss, } def get_parent(node): """Retrieves the parent tf.Tensor of node in the computation graph. Args: node: A tf.Tensor. Raises: ValueError: If the node has more than one input op. ValueError: If the node has more than one parent tf.Tensor. Returns: The parent tensor of the node on the computation graph. """ edge = tensorflow_graph_util.expand_inputs(node) if len(edge) != 1: raise ValueError('{} has more than one input op.'.format(node)) parent = tensorflow_graph_util.expand_inputs(edge[0]) if len(parent) != 1: raise ValueError('{} has more than one parent tensor.'.format(node)) return parent[0] def serialize_loss(loss): """Serialize a valid Keras Kfac loss argument.""" def serialize(x): return x if isinstance(x, six.string_types) else losses.serialize(x) if not loss or isinstance(loss, six.string_types): return loss elif isinstance(loss, dict): return {k: serialize(v) for k, v in loss.items()} elif isinstance(loss, list): return [serialize(v) for v in loss] else: return losses.serialize(loss) def serialize_fisher_approx(fisher_approx): """Serialize a valid fisher approximation dict or list.""" def serialize(key): return (key if isinstance(key, six.string_types) else _CLASS_NAME_PREFIX + key.__name__) if isinstance(fisher_approx, dict): fisher_approx = {serialize(k): v for k, v in fisher_approx.items()} return fisher_approx def _get_verified_dict(container, container_name, layer_names): """Verifies that loss_weights/fisher_approx conform to their specs.""" if container is None or container == {}: # pylint: disable=g-explicit-bool-comparison # The explicit comparison prevents empty lists from passing. return {} elif isinstance(container, dict): string_keys = { str(k) for k in container if isinstance(k, six.string_types) and not k.startswith(_CLASS_NAME_PREFIX) } if string_keys - set(layer_names): raise ValueError('There is a {} without a matching layer' .format(container_name)) return container elif isinstance(container, list): if len(layer_names) != len(container): raise ValueError('Number of {} and layers don\'t match.' 
.format(container_name)) return dict(zip(layer_names, container)) else: raise ValueError('{} must be a list or dict'.format(container_name)) def register_layer(layer_collection, layer, fisher_approx=None, **kwargs): """Registers a single Keras layer with the given layer_collection. Args: layer_collection: LayerCollection object on which the layer will be registered. layer: Keras layer to register with the layer_collection. fisher_approx: Optional string specifying the fisher approximation type. **kwargs: Keyword arguments to be forwarded to the layer registration function. Raises: ValueError: If there is a layer with trainable parameters that isn't Conv1D, Conv2D, Dense, BatchNormalization, LayerNormalization or Embedding. ValueError: If convolutional layers don't use the "channels_last" format. """ # The inbound_nodes property is currently deprecated, but appears to be # supported in non-eager TF 1.x. This may change. # If there are multiple inbound_nodes, it means the model was used as a # callable (i.e. y = model(x)). We assume the inputs/outputs from the call # need to be registered and not the nodes from the original built model or # any other previous calls, since layers can't be used multiple times # (RNN-style) with Keras KFAC. node = layer.inbound_nodes[-1] pre_activation_output = node.output_tensors if hasattr(layer, 'activation') and layer.activation != activations.linear: pre_activation_output = get_parent(pre_activation_output) # This will allow unsupported layers to be in our model as long as KFAC # doesn't have to minimize with respect to those parameters. if layer.count_params() and layer.trainable: if any(isinstance(tensor, (list, tuple)) for tensor in (node.input_tensors, node.output_tensors)): raise ValueError('Individual layers can only have 1 input_tensor and 1 ' 'output tensor. You are likely using an unsupported ' 'layer type. Error on layer {}'.format(layer)) weights = layer.trainable_weights kwargs.update({ 'inputs': node.input_tensors, 'outputs': pre_activation_output, 'params': weights if len(weights) > 1 else weights[0], 'approx': fisher_approx, }) # TODO(b/133849249) Support RNNs and other shared weight layers. if isinstance(layer, layers.Dense): layer_collection.register_fully_connected(**kwargs) elif isinstance(layer, layers.Embedding): layer_collection.register_fully_connected(dense_inputs=False, **kwargs) elif isinstance(layer, (layers.BatchNormalization, layers.LayerNormalization)): if not layer.scale: # With Batch/Layer Normalization, the user can specify if they want # the input to be scaled and/or shifted after it is normalized. raise ValueError('Kfac currently does not support batch/layer ' 'normalization with scale=False. Error on layer {}' .format(layer)) # Undo batchnorm by subtracting the shift and dividing by the scale. kwargs['inputs'] = ((kwargs['outputs'] - weights[1]) / weights[0] if layer.center else kwargs['outputs'] / weights) layer_collection.register_scale_and_shift(**kwargs) # A learning_phase of 1 or 0 means it's been set. False means it hasn't. is_phase_set = K.get_value(K.learning_phase()) != False # pylint: disable=g-explicit-bool-comparison if hasattr(layer, 'fused') and layer.fused and not is_phase_set: # For the fused implementation of the BatchNormalization, there are # two ops: one for training and one for inference. When the # learning_phase is set, during layer creation, there is a # tf_utils.smart_cond that will only create one of the ops.
When the # learning_phase is not set, it will create a tf.cond with both ops as # branches. So, when learning_phase is not set, we must add a "use" # for the gamma/beta variables to account for there being two ops that # are consumers of the variables. Linked below is the smart_cond in # BatchNormalization: # https://github.com/tensorflow/tensorflow/blob/59217f581fdef4e5469a98b62e38f851eac88688/tensorflow/python/keras/layers/normalization.py#L513 # Updated 2019-06-22. layer_collection._add_uses(weights, 1) # pylint: disable=protected-access elif all(hasattr(layer, a) for a in ('strides', 'padding', 'dilation_rate')): if layer.data_format != 'channels_last': raise ValueError('KFAC currently only supports the "channels_last" ' 'data format for convolutional layers. Error on ' 'layer {}'.format(layer)) kwargs['padding'] = layer.padding.upper() kwargs['strides'] = [1] + list(layer.strides) + [1] kwargs['dilations'] = [1] + list(layer.dilation_rate) + [1] if isinstance(layer, layers.Conv2D): layer_collection.register_conv2d(**kwargs) elif isinstance(layer, layers.Conv1D): layer_collection.register_conv1d(**kwargs) # Depthwise and Separable Conv2D are not supported yet because they are # experimental in tensorflow_kfac. else: raise ValueError('Unsupported convolutional layer type: {}' .format(layer)) # TODO(b/133849240): Support registering any convolution type. else: raise ValueError('Unsupported layer type: {}'.format(layer)) # TODO(b/133849243): Support registering any generic layer type. def register_loss(layer_collection, layer, loss, **kwargs): """Registers the loss with the layer for the layer_collection. Args: layer_collection: LayerCollection object on which the layer and loss will be registered. layer: Keras layer whose outputs will be used with the loss function. loss: Keras (normal or serialized) loss function. Currently, sparse/normal categorical/binary cross entropy and MSE are supported. **kwargs: Keyword arguments to be forwarded to the function that registers the loss. A couple of notable ones include coeff (the weight of the loss) and seed (the seed used when sampling from the output distribution). Raises: ValueError: If a loss function other than MSE and cross entropy variants is used. """ node = layer.inbound_nodes[-1] pre_activation_output = node.output_tensors if hasattr(layer, 'activation') and layer.activation != activations.linear: pre_activation_output = get_parent(pre_activation_output) # A Keras loss can be a callable class or a function. Their serialized # forms differ. The logic below normalizes these differences. This will # not work for custom losses (we do not intend to support custom loss # functions for now). if not isinstance(loss, six.string_types): loss = losses.serialize(loss) if isinstance(loss, dict): loss = loss['class_name'] loss = loss.replace('_', '').lower() if loss in ('meansquarederror', 'mse'): # We use the actual output here instead of the pre-activations because # MSE is computed with the output. For the logit loss functions, # tensorflow_kfac needs the pre-activations.
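# As a rough illustration (hypothetical layer names, not from this module): a # final Dense layer trained with 'mse' ends up registered by the call below, # roughly layer_collection.register_squared_error_loss(dense_layer.output), # while a 'categorical_crossentropy' loss instead takes the elif branch and is # registered through register_softmax_cross_entropy_loss using the pre-softmax # logits rather than the activated output.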
layer_collection.register_squared_error_loss(layer.output, **kwargs) elif loss in _KERAS_LOSS_TO_KFAC_REGISTER_FUNC: _KERAS_LOSS_TO_KFAC_REGISTER_FUNC[loss]( layer_collection, logits=pre_activation_output, **kwargs) else: raise ValueError('Unsupported loss function: {}'.format(loss)) def get_layer_collection(model, loss=None, loss_weights=None, fisher_approx=None, layer_collection=None, seed=None): """Get layer collection with all layers and loss registered. Args: model: Keras model whose layers to register. Currently, Conv1D, Conv2D, Dense, BatchNormalization, LayerNormalization and Embedding layers are supported in a Functional or Sequential model. Other layer types are supported as long as they aren't trainable (or don't have weights). Nested models are supported. loss: Optional Keras (normal or serialized) loss function. Could be a list, or a dictionary mapping layer names to (normal or serialized) loss functions, if there are multiple losses. Currently, sparse/normal categorical/binary cross entropy and MSE are supported. You must register at least one loss with the layer collection before it can be used. loss_weights: An optional list of coefficients or a dictionary mapping layer names to the coefficient for each loss function. If it is a list, there must be the same number of coefficients as loss functions. If it is a dictionary and a coefficient is not given for a loss function, a coefficient of 1.0 will be used. fisher_approx: An optional list of approximations or a dictionary mapping layer name/class to fisher approximation type. If it is a list, there must be the same number of approximations as there are layers with trainable parameters. For each layer, the approximation is determined as follows: if fisher_approx is a dictionary, first we check if the name is in the dict; if it isn't found, the layer class is checked; if that isn't found, the default is used. When fisher_approx is a list, the order of the approximations must match the order of the layers with trainable parameters given by model.layers. None is a valid dict/list entry and indicates to use the default approximation for that layer. layer_collection: Optional LayerCollection object on which the model and loss will be registered. seed: Optional integer specifying the TensorFlow random seed. To get deterministic behaviour, the seed needs to be set because the targets are sampled to approximate the fisher. Raises: ValueError: If there is a layer with trainable parameters that isn't Conv1D, Conv2D, Dense, BatchNormalization, LayerNormalization or Embedding. ValueError: If a loss function other than MSE and cross entropy variants is used. ValueError: If there isn't a one-to-one correspondence between loss/loss_weights and output layers, or if loss_weights isn't a list/dict. ValueError: If convolutional layers don't use the "channels_last" format. Returns: A kfac.LayerCollection with the model's layers and loss registered. """ if not layer_collection: layer_collection = kfac_layer_collection.LayerCollection() if not loss: loss = {} elif isinstance(loss, dict): if set(model.output_names) != set(loss.keys()): raise ValueError('Output layer names and loss dict keys don\'t match' ' \nmodel.output_names: {} \nloss dict keys: {}' .format(model.output_names, loss.keys())) elif isinstance(loss, list): if len(model.output_names) != len(loss): raise ValueError('Number of loss list items doesn\'t match number of ' 'output layers. 
\nmodel.output_names: {} \nloss list: ' '{}'.format(model.output_names, loss)) loss = dict(zip(model.output_names, loss)) else: if len(model.output_names) > 1: raise ValueError('More output layers than losses. \n' 'model.output_names: {} \nloss: {}' .format(model.output_names, loss)) # When the model is used as a callable, the model's output_names may not # match the actual output layer's name. In the one output case, we always # want the last layer, so we use the last layer's name. loss = {model.layers[-1].name: loss} # We want to do a left-to-right depth-first traversal of the model to get the # correct flattened order of the layers. The order only matters for the # fisher_approx in list form. flattened_layers = [] layer_stack = model.layers[::-1] while layer_stack: layer = layer_stack.pop() if hasattr(layer, 'layers'): if layer.name in loss: if len(layer.output_names) > 1: raise ValueError('Nested models with multiple outputs are ' 'unsupported.') loss[layer.output_names[0]] = loss.pop(layer.name) layer_stack += layer.layers[::-1] else: flattened_layers.append(layer) trainable_layer_names = [l.name for l in flattened_layers if l.count_params() and l.trainable] fisher_approx = _get_verified_dict(fisher_approx, 'fisher_approx', trainable_layer_names) # The Optimizer class passes in a serialized fisher_approx dictionary, but the # user may not. We serialize it so we can use it uniformly. fisher_approx = serialize_fisher_approx(fisher_approx) loss_weights = _get_verified_dict(loss_weights, 'loss_weights', model.output_names) for layer in flattened_layers: if layer.name in fisher_approx: approx = fisher_approx[layer.name] else: approx = fisher_approx.get( _CLASS_NAME_PREFIX + layer.__class__.__name__, None) register_layer(layer_collection, layer, fisher_approx=approx) if layer.name in loss: register_loss(layer_collection=layer_collection, layer=layer, loss=loss[layer.name], coeff=loss_weights.get(layer.name, 1.0), seed=seed) return layer_collection def get_loss_fn(model, loss, training=None, loss_weights=None, reduce_fn=tf.reduce_mean, name='loss'): """Creates a loss function to be used for KFAC's adaptive damping. This allows Keras KFAC to automatically create the loss function to use for adaptive_damping. This function would also be useful for a custom training loop that uses adaptive_damping. The returned loss function currently does not support masks or sample_weights. Currently, if you use a categorical crossentropy loss, due to the implementation of tf.keras.losses.*_crossentropy, it will grab the logits whether you use a softmax at the end of your model or not. This is true as of August 1, 2019. Code below: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/backend.py#L4322 Args: model: tf.keras.Model model that will be used with the inputs to the returned loss_fn. loss: Potentially serialized tf.keras.losses.* loss function(s)/class(es). If the model has multiple outputs, this must be a list of losses that matches the order of model.outputs, or a dictionary with names matching output_names. Must accept kwargs y_pred and y_true. Note that if your model's output are logits, you should pass a callable Keras with from_logits=True. This function could be a non-Keras loss, but it is untested in this case. training: Boolean indicating whether or not the loss is used in training or test time. This is necessary to set the proper mode for batch norm and dropout layers. If None then falls back to Keras behavior of calling the model without passing a value for training. 
loss_weights: If you have multiple losses, a list or dictionary of weights for each loss. A default value of 1.0 is given for losses that don't have a weight when a dictionary is passed. reduce_fn: The function that will be used to aggregate the loss tensor. tf.reduce_mean by default. You may replace this with the identity if your loss does a reduction by default. Depending on how you compute your loss in a distributed setting, you may want to modify this function (for example, if you sum across replicas, then the reduce_fn might be lambda x: tf.reduce_sum(x) * (1.0 / global_batch_size)). name: Name scope for the loss_fn ops. Raises: ValueError: If the loss is a dictionary. Returns: A function that takes inputs and optionally a prediction and will return a loss. This can be used as the KFAC loss_fn for adaptive damping. """ if isinstance(loss, six.string_types): loss = losses.deserialize(loss) elif isinstance(loss, dict): loss = [loss[n] for n in model.output_names] if isinstance(loss, list): loss = [losses.deserialize(l) if isinstance(l, six.string_types) else l for l in loss] if isinstance(loss_weights, dict): loss_weights = [loss_weights.get(n, 1.0) for n in model.output_names] def loss_fn(inputs, prediction=None): """Computes loss for a model given inputs. This function is meant to be used with K-FAC's adaptive damping, which is why the prediction is optional (since K-FAC wants to compute the loss just given inputs). Args: inputs: A tuple with (model_input(s), label(s)), where both elements are tensors or lists/tuples of tensors. prediction: The output of the model given the inputs. If this isn't provided, the prediction will be computed via prediction = model(inputs[0]). Returns: A tensor with the total reduced loss including regularization and other layer specific losses. """ with tf.name_scope(name): x, y = inputs if prediction is None: if training is not None: prediction = model(x, training=training) else: prediction = model(x) if isinstance(prediction, (tuple, list)): reduced_losses = [reduce_fn(fn(y_pred=pred_i, y_true=y_i)) for fn, pred_i, y_i in zip(loss, prediction, y)] if loss_weights: reduced_losses = [l * w for l, w in zip(reduced_losses, loss_weights)] total_loss = tf.add_n(reduced_losses) else: total_loss = reduce_fn(loss(y_pred=prediction, y_true=y)) # Adds regularization penalties and other custom layer specific losses. if model.losses: total_loss += tf.add_n(model.losses) return total_loss return loss_fn ================================================ FILE: kfac/python/kernel_tests/data_reader_test.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# ============================================================================== """Tests for CachedDataReader class.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import tensorflow.compat.v1 as tf from kfac.python.ops.kfac_utils import data_reader class DataReaderTest(tf.test.TestCase): def test_read_batch(self): max_batch_size = 10 batch_size_schedule = [2, 4, 6, 8] data_set = tf.random_uniform(shape=(max_batch_size, 784), maxval=1.) var_data = data_reader.CachedDataReader( (data_set,), max_batch_size) cur_batch_size = tf.placeholder( shape=(), dtype=tf.int32, name='cur_batch_size') # Force create the ops data = var_data(cur_batch_size)[0] with self.test_session() as sess: sess.run(tf.global_variables_initializer()) coord = tf.train.Coordinator() tf.train.start_queue_runners(sess=sess, coord=coord) for batch_size in batch_size_schedule: data_ = sess.run( data, feed_dict={cur_batch_size: batch_size}) self.assertEqual(len(data_), batch_size) self.assertEqual(len(data_[0]), 784) def test_cached_batch(self): max_batch_size = 100 data_set = tf.random_uniform(shape=(max_batch_size, 784), maxval=1.) var_data = data_reader.CachedDataReader( (data_set,), max_batch_size) cur_batch_size = tf.placeholder( shape=(), dtype=tf.int32, name='cur_batch_size') # Force create the ops data = var_data(cur_batch_size)[0] with self.test_session() as sess: sess.run(tf.global_variables_initializer()) coord = tf.train.Coordinator() tf.train.start_queue_runners(sess=sess, coord=coord) data_ = sess.run(data, feed_dict={cur_batch_size: 25}) stored_data_ = sess.run(var_data.cached_batch)[0] self.assertListEqual(list(data_[1]), list(stored_data_[1])) if __name__ == '__main__': tf.disable_v2_behavior() tf.test.main() ================================================ FILE: kfac/python/kernel_tests/estimator_test.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for kfac.estimator.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import numpy as np import tensorflow.compat.v1 as tf from kfac.python.ops import estimator from kfac.python.ops import fisher_factors as ff from kfac.python.ops import layer_collection as lc from kfac.python.ops import utils # We need to set these constants since the numerical values used in the tests # were chosen when these used to be the defaults. 
ff.set_global_constants(zero_debias=False) _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] class EstimatorTest(tf.test.TestCase): def setUp(self): self._graph = tf.Graph() with self._graph.as_default(): self.layer_collection = lc.LayerCollection() self.inputs = tf.random_normal((2, 2), dtype=tf.float32) self.weights = tf.get_variable("w", shape=(2, 2), dtype=tf.float32) self.bias = tf.get_variable( "b", initializer=tf.zeros_initializer(), shape=(2, 1)) self.output = tf.matmul(self.inputs, self.weights) + self.bias # Only register the weights. self.layer_collection.register_fully_connected( params=(self.weights,), inputs=self.inputs, outputs=self.output) self.outputs = tf.tanh(self.output) self.targets = tf.zeros_like(self.outputs) self.layer_collection.register_categorical_predictive_distribution( logits=self.outputs, targets=self.targets) def testEstimatorInitManualRegistration(self): with self._graph.as_default(): # We should be able to build an estimator for only the registered vars. estimator.FisherEstimatorRoundRobin( variables=[self.weights], cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection ) # Check that we throw an error if we try to build an estimator for vars # that were not manually registered. with self.assertRaises(ValueError): est = estimator.FisherEstimatorRoundRobin( variables=[self.weights, self.bias], cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection ) est.make_vars_and_create_op_thunks() # Check that we throw an error if we don't include registered variables, # i.e. self.weights with self.assertRaises(ValueError): est = estimator.FisherEstimatorRoundRobin( variables=[], cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection) est.make_vars_and_create_op_thunks() @tf.test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) def testVariableWrongNumberOfUses(self, mock_uses): with self.assertRaises(ValueError): est = estimator.FisherEstimatorRoundRobin( variables=[self.weights], cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection) est.make_vars_and_create_op_thunks() def testInvalidEstimationMode(self): with self.assertRaises(ValueError): est = estimator.FisherEstimatorRoundRobin( variables=[self.weights], cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection, estimation_mode="not_a_real_mode") est.make_vars_and_create_op_thunks() def testGradientsModeBuild(self): with self._graph.as_default(): est = estimator.FisherEstimatorRoundRobin( variables=[self.weights], cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection, estimation_mode="gradients") est.make_vars_and_create_op_thunks() def testEmpiricalModeBuild(self): with self._graph.as_default(): est = estimator.FisherEstimatorRoundRobin( variables=[self.weights], cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection, estimation_mode="empirical") est.make_vars_and_create_op_thunks() def testCurvaturePropModeBuild(self): with self._graph.as_default(): est = estimator.FisherEstimatorRoundRobin( variables=[self.weights], cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection, estimation_mode="curvature_prop") est.make_vars_and_create_op_thunks() def testExactModeBuild(self): with self._graph.as_default(): est = estimator.FisherEstimatorRoundRobin( variables=[self.weights], cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection, estimation_mode="exact") est.make_vars_and_create_op_thunks() def test_cov_update_thunks(self): 
"""Ensures covariance update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, damping=0.2, cov_ema_decay=0.0) # Construct an op that executes one covariance update per step. global_step = tf.train.get_or_create_global_step() (cov_variable_thunks, cov_update_op_thunks, _, _) = fisher_estimator.create_ops_and_vars_thunks() for thunk in cov_variable_thunks: thunk() cov_matrices = [ fisher_factor.cov for fisher_factor in self.layer_collection.get_factors() ] cov_update_op = tf.case([(tf.equal(global_step, i), thunk) for i, thunk in enumerate(cov_update_op_thunks)]) increment_global_step = global_step.assign_add(1) sess.run(tf.global_variables_initializer()) initial_cov_values = sess.run(cov_matrices) # Ensure there's one update per covariance matrix. self.assertEqual(len(cov_matrices), len(cov_update_op_thunks)) # Test is no-op if only 1 covariance matrix. assert len(cov_matrices) > 1 for i in range(len(cov_matrices)): # Compare new and old covariance values new_cov_values = sess.run(cov_matrices) is_cov_equal = [ np.allclose(initial_cov_value, new_cov_value) for (initial_cov_value, new_cov_value) in zip(initial_cov_values, new_cov_values) ] num_cov_equal = sum(is_cov_equal) # Ensure exactly one covariance matrix changes per step. self.assertEqual(num_cov_equal, len(cov_matrices) - i) # Run all covariance update ops. sess.run(cov_update_op) sess.run(increment_global_step) def test_round_robin_placement(self): """Check if the ops and variables are placed on devices correctly.""" with self._graph.as_default(): fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, damping=0.2, cov_ema_decay=0.0, cov_devices=["/cpu:{}".format(i) for i in range(2)], inv_devices=["/cpu:{}".format(i) for i in range(2)]) # Construct an op that executes one covariance update per step. (cov_update_thunks, inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks( scope="test") cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) self.assertEqual(cov_update_ops[0].device, "/device:CPU:0") self.assertEqual(cov_update_ops[1].device, "/device:CPU:1") self.assertEqual(inv_update_ops[0].device, "/device:CPU:0") self.assertEqual(inv_update_ops[1].device, "/device:CPU:1") cov_matrices = [ fisher_factor._cov._var for fisher_factor in self.layer_collection.get_factors() ] inv_matrices = [ matrix for fisher_factor in self.layer_collection.get_factors() for matrix in fisher_factor._matpower_by_exp_and_damping.values() ] self.assertEqual(cov_matrices[0].device, "/device:CPU:0") self.assertEqual(cov_matrices[1].device, "/device:CPU:1") # Inverse matrices need to be explicitly placed. self.assertEqual(inv_matrices[0].device, "") self.assertEqual(inv_matrices[1].device, "") def test_inv_update_thunks(self): """Ensures inverse update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, damping=0.2, cov_ema_decay=0.0) # Construct op that updates one inverse per global step. 
global_step = tf.train.get_or_create_global_step() (cov_variable_thunks, _, inv_variable_thunks, inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks() for thunk in cov_variable_thunks: thunk() for thunk in inv_variable_thunks: thunk() inv_matrices = [ matrix for fisher_factor in self.layer_collection.get_factors() for matrix in fisher_factor._matpower_by_exp_and_damping.values() ] inv_update_op = tf.case([(tf.equal(global_step, i), thunk) for i, thunk in enumerate(inv_update_op_thunks)]) increment_global_step = global_step.assign_add(1) sess.run(tf.global_variables_initializer()) initial_inv_values = sess.run(inv_matrices) # Ensure there's one update per inverse matrix. This is true as long as # there's no fan-in/fan-out or parameter re-use. self.assertEqual(len(inv_matrices), len(inv_update_op_thunks)) # Test is no-op if only 1 invariance matrix. assert len(inv_matrices) > 1 # Assign each covariance matrix a value other than the identity. This # ensures that the inverse matrices are updated to something different as # well. sess.run([ fisher_factor._cov.add_to_average( 2 * tf.eye(int(fisher_factor._cov_shape[0]))) for fisher_factor in self.layer_collection.get_factors() ]) for i in range(len(inv_matrices)): # Compare new and old inverse values new_inv_values = sess.run(inv_matrices) is_inv_equal = [ np.allclose(initial_inv_value, new_inv_value) for (initial_inv_value, new_inv_value) in zip(initial_inv_values, new_inv_values) ] num_inv_equal = sum(is_inv_equal) # Ensure exactly one inverse matrix changes per step. self.assertEqual(num_inv_equal, len(inv_matrices) - i) # Run all inverse update ops. sess.run(inv_update_op) sess.run(increment_global_step) if __name__ == "__main__": tf.disable_v2_behavior() tf.test.main() ================================================ FILE: kfac/python/kernel_tests/graph_search_test.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Tests for tensormatch/graph_search.py.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import tensorflow.compat.v1 as tf from kfac.python.ops import fisher_blocks as fb from kfac.python.ops import layer_collection as lc from kfac.python.ops import optimizer from kfac.python.ops.tensormatch import graph_search as gs def _build_model(): w = tf.get_variable('W', [10, 10]) b_1 = tf.get_variable('b_1', [ 10, ]) b_0 = tf.get_variable('b_0', [ 10, ]) x = tf.placeholder(tf.float32, shape=(32, 10)) y = tf.placeholder(tf.float32, shape=(32, 10)) pre_bias_0 = tf.matmul(x, w) pre_bias_1 = tf.matmul(y, w) out_0 = pre_bias_0 + b_0 # pylint: disable=unused-variable out_1 = pre_bias_1 + b_1 # pylint: disable=unused-variable tensor_dict = {} tensor_dict['w'] = w tensor_dict['b_0'] = b_0 tensor_dict['b_1'] = b_1 tensor_dict['x'] = x tensor_dict['y'] = y tensor_dict['pre_bias_0'] = pre_bias_0 tensor_dict['pre_bias_1'] = pre_bias_1 tensor_dict['out_0'] = out_0 tensor_dict['out_1'] = out_1 return tensor_dict def _build_mock_records(): tensor_dict = _build_model() weight_record = gs.MatchRecord( record_type=gs.RecordType.fully_connected, params=tensor_dict['w'], tensor_set={ tensor_dict['x'], tensor_dict['w'], tensor_dict['pre_bias_0'] }) weight_and_bias_0_record = gs.MatchRecord( record_type=gs.RecordType.fully_connected, params=(tensor_dict['w'], tensor_dict['b_0']), tensor_set={ tensor_dict['x'], tensor_dict['w'], tensor_dict['pre_bias_0'], tensor_dict['b_0'], tensor_dict['out_0'] }) bias_0_record = gs.MatchRecord( record_type=gs.RecordType.fully_connected, params=tensor_dict['b_0'], tensor_set={ tensor_dict['pre_bias_0'], tensor_dict['b_0'], tensor_dict['out_0'] }) weight_and_bias_1_record = gs.MatchRecord( record_type=gs.RecordType.fully_connected, params=(tensor_dict['w'], tensor_dict['b_1']), tensor_set={ tensor_dict['y'], tensor_dict['w'], tensor_dict['pre_bias_1'], tensor_dict['b_1'], tensor_dict['out_1'] }) record_list_dict = collections.defaultdict(list) for record in [ weight_record, weight_and_bias_0_record, bias_0_record, weight_and_bias_1_record ]: record_list_dict[record.params].append(record) return tensor_dict, dict(record_list_dict) def assert_fisher_blocks_match(test_case, layer_collection_a, layer_collection_b): """Check that two `LayerCollection`s have matching fisher_blocks.""" fisher_blocks_a = layer_collection_a.fisher_blocks fisher_blocks_b = layer_collection_b.fisher_blocks test_case.assertSetEqual( set(fisher_blocks_a.keys()), set(fisher_blocks_b.keys())) for parameters, block_a in fisher_blocks_a.items(): block_b = fisher_blocks_b[parameters] test_case.assertEqual(type(block_a), type(block_b)) if hasattr(block_a, '_inputs'): test_case.assertEqual(block_a._inputs, block_b._inputs) # pylint: disable=protected-access test_case.assertEqual(block_a._outputs, block_b._outputs) # pylint: disable=protected-access else: test_case.assertEqual(block_a._params, block_b._params) # pylint: disable=protected-access def sparse_softmax_cross_entropy(labels, logits, num_classes, weights=1.0, label_smoothing=0.1): """Softmax cross entropy with example weights, label smoothing.""" assert_valid_label = [ tf.assert_greater_equal(labels, tf.cast(0, dtype=tf.int64)), tf.assert_less(labels, tf.cast(num_classes, dtype=tf.int64)) ] with tf.control_dependencies(assert_valid_label): labels = tf.reshape(labels, [-1]) dense_labels = 
tf.one_hot(labels, num_classes) loss = tf.losses.softmax_cross_entropy( onehot_labels=dense_labels, logits=logits, weights=weights, label_smoothing=label_smoothing) return loss class GraphSearchTestCase(tf.test.TestCase): def testRegisterLayers(self): """Ensure graph search can find a single layer network.""" with tf.Graph().as_default(): layer_collection = lc.LayerCollection() # Construct a 1-layer model. inputs = tf.ones((2, 1)) * 2 weights = tf.get_variable( 'w', shape=(1, 1), dtype=tf.float32, initializer=tf.random_normal_initializer) bias = tf.get_variable( 'b', initializer=tf.zeros_initializer(), shape=(1, 1)) non_variable_bias = tf.ones((1, 1)) output = tf.matmul(inputs, weights) + bias + non_variable_bias logits = tf.tanh(output) # Register posterior distribution. Graph search will infer variables # needed to construct this. layer_collection.register_categorical_predictive_distribution(logits) # Register variables. gs.register_layers(layer_collection, tf.trainable_variables()) # Ensure 1-layer got registered. self.assertEqual( [(weights, bias)], list(layer_collection.fisher_blocks.keys())) self.assertEqual(1, len(layer_collection.losses)) def test_register_records_order(self): """Ensure records are always registered in the same order.""" with tf.Graph().as_default(): data = {'inputs': tf.zeros([10, 4]), 'outputs': tf.zeros([10, 3]), 'dense_inputs': True} params1 = tf.get_variable('w1', [4, 3]) record1 = gs.MatchRecord( gs.RecordType.fully_connected, params1, set(), data=data) params2 = (tf.get_variable('w2', [4, 3]), tf.get_variable('b2', [3])) record2 = gs.MatchRecord( gs.RecordType.fully_connected, params2, set(), data=data) # Create a dict of records. records = collections.OrderedDict() records[params1] = [record1] records[params2] = [record2] # Register variables. layer_collection = lc.LayerCollection(name='lc1') gs.register_records(layer_collection, records) # Ensure order matches lexicographic order. self.assertEqual([params2, params1], list(layer_collection.fisher_blocks.keys())) # Create a dict of records in a different order. records = collections.OrderedDict() records[params2] = [record2] records[params1] = [record1] # Register variables. layer_collection = lc.LayerCollection(name='lc2') gs.register_records(layer_collection, records) # Ensure order matches lexicographic order. self.assertEqual([params2, params1], list(layer_collection.fisher_blocks.keys())) def test_multitower_examples_model(self): """Ensure graph search runs properly on a multitower setup. This test uses linear_model from examples/convnets. """ with tf.Graph().as_default(): def linear_model(images, labels, num_classes): """Creates a linear model. Args: images: The input image tensors, a tensor of size (batch_size x height_in x width_in x channels). labels: The sparse target labels, a tensor of size (batch_size x 1). num_classes: The number of classes, needed for one-hot encoding (int). Returns: loss: The total loss for this model (0-D tensor). logits: Predictions for this model (batch_size x num_classes). """ images = tf.reshape(images, [images.shape[0], -1]) logits = tf.layers.dense(images, num_classes, name='logits') loss = sparse_softmax_cross_entropy(labels, logits, num_classes) return loss, logits model = linear_model layer_collection = lc.LayerCollection() num_towers = 2 batch_size = num_towers num_classes = 2 # Set up data. 
images = tf.random_uniform(shape=[batch_size, 32, 32, 1]) labels = tf.random_uniform( dtype=tf.int64, shape=[batch_size, 1], maxval=num_classes) tower_images = tf.split(images, num_towers) tower_labels = tf.split(labels, num_towers) # Build model. losses = [] logits = [] for tower_id in range(num_towers): tower_name = 'tower%d' % tower_id with tf.name_scope(tower_name): with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): current_loss, current_logits = model( tower_images[tower_id], tower_labels[tower_id], num_classes + 1) layer_collection.register_categorical_predictive_distribution( current_logits, name='logits') losses.append(current_loss) logits.append(current_logits) # Run the graph scanner. with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): gs.register_layers(layer_collection, tf.trainable_variables()) self.assertEqual(len(layer_collection.fisher_blocks), 1) fisher_block = list(layer_collection.fisher_blocks.values())[0] self.assertIsInstance(fisher_block, fb.FullyConnectedKFACBasicFB) self.assertEqual(fisher_block.num_registered_towers, num_towers) global_step = tf.train.get_or_create_global_step() opt = optimizer.KfacOptimizer( learning_rate=0.1, cov_ema_decay=0.1, damping=0.1, layer_collection=layer_collection, momentum=0.1) cost = tf.reduce_mean(losses) (cov_update_thunks, inv_update_thunks) = opt.make_vars_and_create_op_thunks() cov_update_op = tf.group(*(thunk() for thunk in cov_update_thunks)) inv_update_op = tf.group(*(thunk() for thunk in inv_update_thunks)) train_op = opt.minimize(cost, global_step=global_step) init = tf.global_variables_initializer() # Run a single training step. with self.test_session() as sess: sess.run(init) sess.run([cov_update_op]) sess.run([inv_update_op]) sess.run([train_op]) def test_multitower_multi_loss_function(self): """Test multitower setup with multiple loss functions. The automatic graph scanner should handle multiple loss functions per tower, as long as they're registered in a consistent order. 
""" with tf.Graph().as_default(): w_1 = tf.get_variable('w_1', shape=[10, 10]) b_1 = tf.get_variable('b_1', shape=[10]) w_2 = tf.get_variable('w_2', shape=[10, 10]) b_2 = tf.get_variable('b_2', shape=[10]) layer_collection = lc.LayerCollection() layer_collection_manual = lc.LayerCollection() for tower_num in range(5): x = tf.placeholder(tf.float32, shape=(32, 10)) logits_1 = tf.matmul(x, w_1) + b_1 logits_2 = tf.matmul(x, w_2) + b_2 if tower_num == 0: reuse = False else: reuse = True with tf.variable_scope('tower%d' % tower_num, reuse=reuse): for l in [layer_collection, layer_collection_manual]: l.register_categorical_predictive_distribution( logits_1, name='loss_1') l.register_categorical_predictive_distribution( logits_2, name='loss_2') layer_collection_manual.register_fully_connected((w_1, b_1), x, logits_1) layer_collection_manual.register_fully_connected((w_2, b_2), x, logits_2) gs.register_layers(layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)) assert_fisher_blocks_match(self, layer_collection, layer_collection_manual) def test_filter_user_registered_records(self): """Matches containing already registered variables should be removed.""" with tf.Graph().as_default(): tensor_dict, record_list_dict = _build_mock_records() layer_collection = lc.LayerCollection() layer_collection.register_fully_connected( params=(tensor_dict['w'], tensor_dict['b_1']), inputs=tensor_dict['x'], outputs=tensor_dict['pre_bias_0']) user_registered_variables = set() for params in layer_collection.fisher_blocks.keys(): for variable in gs.ensure_sequence(params): user_registered_variables.add(variable) filtered_record_list_dict = gs.filter_user_registered_records( record_list_dict, user_registered_variables) expected_keys = [tensor_dict['b_0']] self.assertDictEqual(filtered_record_list_dict, {k: record_list_dict[k] for k in expected_keys}) def test_filter_grouped_variable_records(self): """Matches violating specified parameter groupings should be removed.""" with tf.Graph().as_default(): tensor_dict, record_list_dict = _build_mock_records() layer_collection = lc.LayerCollection() layer_collection.define_linked_parameters(params=tensor_dict['w']) filtered_record_list_dict = gs.filter_grouped_variable_records( layer_collection, record_list_dict) expected_keys = [tensor_dict['w'], tensor_dict['b_0']] self.assertDictEqual(filtered_record_list_dict, {k: record_list_dict[k] for k in expected_keys}) with tf.Graph().as_default(): tensor_dict, record_list_dict = _build_mock_records() layer_collection = lc.LayerCollection() layer_collection.define_linked_parameters( params=(tensor_dict['w'], tensor_dict['b_0'])) filtered_record_list_dict = gs.filter_grouped_variable_records( layer_collection, record_list_dict) expected_keys = [(tensor_dict['w'], tensor_dict['b_0'])] self.assertDictEqual(filtered_record_list_dict, {k: record_list_dict[k] for k in expected_keys}) def test_filter_subgraph_records(self): """Matches that are strict subgraphs of other matches should be removed.""" with tf.Graph().as_default(): tensor_dict, record_list_dict = _build_mock_records() filtered_record_list_dict = gs.filter_subgraph_records(record_list_dict) expected_keys = [(tensor_dict['w'], tensor_dict['b_0']), (tensor_dict['w'], tensor_dict['b_1'])] self.assertDictEqual(filtered_record_list_dict, {k: record_list_dict[k] for k in expected_keys}) def test_rnn_multi(self): """Test automatic registration on a static RNN. The model tested here is designed for MNIST classification. 
To classify images using a recurrent neural network, we consider every image row as a sequence of pixels. Because MNIST image shape is 28*28px, we will then handle 28 sequences of 28 steps for every sample. """ with tf.Graph().as_default(): dtype = tf.float32 n_input = 28 # MNIST data input (img shape: 28*28) n_timesteps = 28 # timesteps n_hidden = 128 # hidden layer num of features n_classes = 10 # MNIST total classes (0-9 digits) x = tf.placeholder(dtype, [None, n_timesteps, n_input]) y = tf.placeholder(tf.int32, [None]) x_unstack = tf.unstack(x, n_timesteps, 1) w_input = tf.get_variable( 'w_input', shape=[n_input, n_hidden], dtype=dtype) b_input = tf.get_variable('b_input', shape=[n_hidden], dtype=dtype) w_recurrent = tf.get_variable( 'w_recurrent', shape=[n_hidden, n_hidden], dtype=dtype) b_recurrent = tf.get_variable( 'b_recurrent', shape=[n_hidden], dtype=dtype) w_output = tf.get_variable( 'w_output', shape=[n_hidden, n_classes], dtype=dtype) b_output = tf.get_variable('b_output', shape=[n_classes], dtype=dtype) layer_collection_manual = lc.LayerCollection() layer_collection_auto = lc.LayerCollection() a = tf.zeros(tf.convert_to_tensor([tf.shape(x_unstack[0])[0], n_hidden]), dtype=dtype) # Here 'a' are the activations, 's' the pre-activations. a_list = [a] s_input_list = [] s_recurrent_list = [] s_list = [] s_out_list = [] cost = 0.0 for i in range(len(x_unstack)): input_ = x_unstack[i] s_in = tf.matmul(input_, w_input) + b_input s_rec = tf.matmul(a, w_recurrent) + b_recurrent s = s_in + s_rec s_input_list.append(s_in) s_recurrent_list.append(s_rec) s_list.append(s) a = tf.tanh(s) a_list.append(a) s_out = tf.matmul(a, w_output) + b_output s_out_list.append(s_out) if i == len(x_unstack) - 1: labels = y else: labels = tf.zeros([tf.shape(y)[0]], dtype=tf.int32) cost += tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( logits=s_out, labels=labels)) layer_collection_manual.register_categorical_predictive_distribution( s_out) layer_collection_auto.register_categorical_predictive_distribution( s_out) layer_collection_manual.register_fully_connected_multi( (w_input, b_input), x_unstack, s_input_list) layer_collection_manual.register_fully_connected_multi( (w_recurrent, b_recurrent), a_list[:-1], s_recurrent_list) layer_collection_manual.register_fully_connected_multi( (w_output, b_output), a_list[1:], s_out_list) gs.register_layers(layer_collection_auto, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)) assert_fisher_blocks_match(self, layer_collection_manual, layer_collection_auto) def test_graph_search_match_fail(self): """Tests graph search with linked bias tensors. In this code snippet two non adjacent bias tensors are linked together. There is no fisher block in kfac that matches this configuration, so the biases should not be registered. """ with tf.Graph().as_default(): tensor_dict = _build_model() layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(tensor_dict['out_0']) layer_collection.register_squared_error_loss(tensor_dict['out_1']) # TODO(b/69055612): remove this manual registration once layer_collection # implements register_fully_connected_multi. 
layer_collection.register_fully_connected( tensor_dict['w'], tensor_dict['x'], tensor_dict['pre_bias_0']) layer_collection.define_linked_parameters((tensor_dict['b_0'], tensor_dict['b_1'])) with self.assertRaises(ValueError) as cm: gs.register_layers(layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)) self.assertIn('in linked group', str(cm.exception)) self.assertIn('was not matched', str(cm.exception)) self.assertIn( str(frozenset([tensor_dict['b_0'], tensor_dict['b_1']])), str(cm.exception)) def test_specify_approximation(self): """Test specifying approximations. If linked parameters are identified along with an approximation, then that approximation should be used when registering those parameters. """ with tf.Graph().as_default(): w_0 = tf.get_variable('w_0', [10, 10]) w_1 = tf.get_variable('w_1', [10, 10]) b_0 = tf.get_variable('b_0', [10]) b_1 = tf.get_variable('b_1', [10]) x_0 = tf.placeholder(tf.float32, shape=(32, 10)) x_1 = tf.placeholder(tf.float32, shape=(32, 10)) pre_bias_0 = tf.matmul(x_0, w_0) pre_bias_1 = tf.matmul(x_1, w_1) out_0 = pre_bias_0 + b_0 # pylint: disable=unused-variable out_1 = pre_bias_1 + b_1 # pylint: disable=unused-variable # Group variables as affine layers. layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(out_0) layer_collection.register_squared_error_loss(out_1) layer_collection.define_linked_parameters( (w_0, b_0), approximation=lc.APPROX_KRONECKER_NAME) layer_collection.define_linked_parameters( (w_1, b_1), approximation=lc.APPROX_DIAGONAL_NAME) gs.register_layers( layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES), batch_size=32) self.assertIsInstance(layer_collection.fisher_blocks[(w_0, b_0)], fb.FullyConnectedKFACBasicFB) self.assertIsInstance(layer_collection.fisher_blocks[(w_1, b_1)], fb.FullyConnectedDiagonalFB) # Group variables as linear layers and generic parameters. layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(out_0) layer_collection.register_squared_error_loss(out_1) layer_collection.define_linked_parameters( w_0, approximation=lc.APPROX_DIAGONAL_NAME) layer_collection.define_linked_parameters( b_0, approximation=lc.APPROX_DIAGONAL_NAME) layer_collection.define_linked_parameters( w_1, approximation=lc.APPROX_KRONECKER_NAME) layer_collection.define_linked_parameters( b_1, approximation=lc.APPROX_FULL_NAME) gs.register_layers( layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES), batch_size=32) self.assertIsInstance(layer_collection.fisher_blocks[w_0], fb.FullyConnectedDiagonalFB) self.assertIsInstance(layer_collection.fisher_blocks[b_0], fb.NaiveDiagonalFB) self.assertIsInstance(layer_collection.fisher_blocks[w_1], fb.FullyConnectedKFACBasicFB) self.assertIsInstance(layer_collection.fisher_blocks[b_1], fb.FullFB) def test_specify_approximation_shared_parameters(self): """Test specifying approximations with layers containing shared parameters. If linked parameters are identified along with an approximation, then that approximation should be used when registering those parameters. 
""" with tf.Graph().as_default(): tensor_dict = _build_model() layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(tensor_dict['out_0']) layer_collection.register_squared_error_loss(tensor_dict['out_1']) layer_collection.define_linked_parameters( tensor_dict['w'], approximation=lc.APPROX_KRONECKER_INDEP_NAME) layer_collection.define_linked_parameters( tensor_dict['b_0'], approximation=lc.APPROX_DIAGONAL_NAME) layer_collection.define_linked_parameters( tensor_dict['b_1'], approximation=lc.APPROX_FULL_NAME) gs.register_layers( layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES), batch_size=1) self.assertIsInstance(layer_collection.fisher_blocks[tensor_dict['w']], fb.FullyConnectedMultiIndepFB) self.assertIsInstance( layer_collection.fisher_blocks[tensor_dict['b_0']], fb.NaiveDiagonalFB) self.assertIsInstance( layer_collection.fisher_blocks[tensor_dict['b_1']], fb.FullFB) def test_tied_weights_untied_bias_registered_weights(self): """Tests that graph search produces right solution on toy model.""" with tf.Graph().as_default(): tensor_dict = _build_model() layer_collection_manual = lc.LayerCollection() layer_collection_manual.register_squared_error_loss(tensor_dict['out_0']) layer_collection_manual.register_squared_error_loss(tensor_dict['out_1']) layer_collection_manual.register_fully_connected_multi( tensor_dict['w'], (tensor_dict['x'], tensor_dict['y']), (tensor_dict['pre_bias_0'], tensor_dict['pre_bias_1'])) layer_collection_manual.register_generic(tensor_dict['b_0'], batch_size=1) layer_collection_manual.register_generic(tensor_dict['b_1'], batch_size=1) layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(tensor_dict['out_0']) layer_collection.register_squared_error_loss(tensor_dict['out_1']) layer_collection.define_linked_parameters((tensor_dict['w'])) gs.register_layers( layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES), batch_size=1) assert_fisher_blocks_match(self, layer_collection, layer_collection_manual) def test_tied_weights_untied_bias_registered_affine(self): """Test registering linked variables. Registering (w, b_1) as linked variables should not raise an error, since the matches with parameters (w) and (w, b_0) will be filtered out. """ with tf.Graph().as_default(): tensor_dict = _build_model() layer_collection_manual = lc.LayerCollection() layer_collection_manual.register_squared_error_loss(tensor_dict['out_0']) layer_collection_manual.register_squared_error_loss(tensor_dict['out_1']) layer_collection_manual.register_fully_connected( params=(tensor_dict['w'], tensor_dict['b_1']), inputs=tensor_dict['y'], outputs=tensor_dict['out_1']) layer_collection_manual.register_generic( tensor_dict['b_0'], batch_size=32) layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(tensor_dict['out_0']) layer_collection.register_squared_error_loss(tensor_dict['out_1']) layer_collection.define_linked_parameters((tensor_dict['w'], tensor_dict['b_1'])) gs.register_layers( layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES), batch_size=32) assert_fisher_blocks_match(self, layer_collection, layer_collection_manual) def test_tied_weights_untied_bias(self): """Tests that ambiguity in graph raises an error. Graph search will find several possible registrations containing w including (w, b_1) & (w, b_2). 
Without any instructions in form of linked tensors or manual registration it defaults to registering an error and suggesting that the user register (w) as a linked tensor. """ with tf.Graph().as_default(): tensor_dict = _build_model() layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(tensor_dict['out_0']) layer_collection.register_squared_error_loss(tensor_dict['out_1']) with self.assertRaises(gs.AmbiguousRegistrationError): gs.register_layers(layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)) def test_tied_weights_untied_bias_registered_bias(self): """Tests that ambiguity in graph raises value error. Graph search will find several possible registrations for tensors. In this registering b_1 as a linked variable will result in an error because there will remain an ambiguity on the other branch of the graph. """ with tf.Graph().as_default(): tensor_dict = _build_model() layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(tensor_dict['out_0']) layer_collection.register_squared_error_loss(tensor_dict['out_1']) layer_collection.define_linked_parameters((tensor_dict['b_1'])) with self.assertRaises(gs.AmbiguousRegistrationError): gs.register_layers(layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)) def test_multi_time_batch_fold(self): """Test that graph search provides desired registration on toy model. In this toy example we apply the same linear layer to two different inputs. This tests whether graph search can correctly group them. Also tests whether batch/time folded is correctly registered as fully connected multi fisher blocks. """ with tf.Graph().as_default(): w = tf.get_variable('W', [10, 10]) b_0 = tf.get_variable('b_0', [ 10, ]) x = tf.placeholder(tf.float32, shape=(32, 10)) y = tf.placeholder(tf.float32, shape=(32, 10)) out_0 = tf.matmul(x, w) + b_0 out_1 = tf.matmul(y, w) + b_0 layer_collection_manual = lc.LayerCollection() layer_collection_manual.register_squared_error_loss(out_0) layer_collection_manual.register_squared_error_loss(out_1) layer_collection_manual.register_fully_connected_multi( (w, b_0), (x, y), (out_0, out_1), num_uses=2) layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(out_0) layer_collection.register_squared_error_loss(out_1) gs.register_layers( layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES), batch_size=16) assert_fisher_blocks_match(self, layer_collection, layer_collection_manual) def test_multiple_weights(self): """Test that graph search provides desired registration on toy model. In this toy example we apply the same linear layer to two different inputs. This tests whether graph search can correctly group them. 
""" with tf.Graph().as_default(): w = tf.get_variable('W', [10, 10]) b_0 = tf.get_variable('b_0', [ 10, ]) x = tf.placeholder(tf.float32, shape=(32, 10)) y = tf.placeholder(tf.float32, shape=(32, 10)) out_0 = tf.matmul(x, w) + b_0 out_1 = tf.matmul(y, w) + b_0 layer_collection_manual = lc.LayerCollection() layer_collection_manual.register_fully_connected_multi((w, b_0), (x, y), (out_0, out_1)) layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(out_0) layer_collection.register_squared_error_loss(out_1) gs.register_layers(layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)) assert_fisher_blocks_match(self, layer_collection, layer_collection_manual) def test_subset_weights_manual_registration(self): """Test that graph search provides desired registration on toy model. In this toy example we apply the same matmul op to two different inputs followed by adding a bias to one of the inputs. This tests whether graph search can correctly group them. """ with tf.Graph().as_default(): w = tf.get_variable('W', [10, 10]) b_0 = tf.get_variable('b_0', [10,]) x = tf.placeholder(tf.float32, shape=(32, 10)) y = tf.placeholder(tf.float32, shape=(32, 10)) out_n1 = tf.matmul(x, w) out_0 = out_n1 + b_0 out_1 = tf.matmul(y, w) layer_collection_manual = lc.LayerCollection() layer_collection_manual.register_fully_connected_multi( w, (x, y), (out_n1, out_1)) layer_collection_manual.register_generic(b_0, batch_size=1) layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(out_0) layer_collection.register_squared_error_loss(out_1) layer_collection.define_linked_parameters(w) gs.register_layers( layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES), batch_size=1) assert_fisher_blocks_match(self, layer_collection, layer_collection_manual) def mixed_usage_test(self): """Tests that graph search raises error on mixed types usage for tensors. Tensors can be reused in various locations in the tensorflow graph. This occurs regularly in the case of recurrent models or models with parallel graphs. However the tensors must be used for the same operation in each location or graph search should raise an error. 
""" with tf.Graph().as_default(): w = tf.get_variable('W', [10, 10]) x = tf.placeholder(tf.float32, shape=(32, 10)) y = tf.placeholder(tf.float32, shape=(32, 10, 10)) out_0 = tf.matmul(x, w) # pylint: disable=unused-variable out_1 = y + w # pylint: disable=unused-variable layer_collection = lc.LayerCollection() with self.assertRaises(ValueError) as cm: gs.register_layers(layer_collection, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)) self.assertIn('mixed record types', str(cm.exception)) def test_resource_variable(self): """Ensures that ResourceVariables can be matched.""" with tf.Graph().as_default(): w = tf.get_variable('w', [10, 10], use_resource=True) b = tf.get_variable('b', [10], use_resource=True) x = tf.placeholder(tf.float32, shape=(32, 10)) out_0 = tf.matmul(x, w) + b layer_collection = lc.LayerCollection() layer_collection.register_squared_error_loss(out_0) gs.register_layers(layer_collection, [w, b]) layer_collection_manual = lc.LayerCollection() layer_collection_manual.register_squared_error_loss(out_0) layer_collection_manual.register_fully_connected((w, b), x, out_0) assert_fisher_blocks_match(self, layer_collection, layer_collection_manual) self.assertEqual(1, len(layer_collection.get_blocks())) if __name__ == '__main__': tf.disable_v2_behavior() tf.test.main() ================================================ FILE: kfac/python/kernel_tests/keras_callbacks_test.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Tests for keras/callbacks.py.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized import numpy as np import tensorflow.compat.v1 as tf from kfac.python.keras import callbacks from kfac.python.keras import optimizers layers = tf.keras.layers _SEED = 1234 class HyperParamTracker(tf.keras.callbacks.Callback): EPOCH, BATCH = range(2) def __init__(self, hyper, record_list, frequency): self.hyper = hyper self.record_list = record_list self.frequency = frequency def on_batch_end(self, batch, logs=None): if self.frequency == HyperParamTracker.BATCH: val = tf.keras.backend.get_value(getattr(self.model.optimizer, self.hyper)) self.record_list.append(val) def on_epoch_end(self, epoch, logs=None): if self.frequency == HyperParamTracker.EPOCH: val = tf.keras.backend.get_value(getattr(self.model.optimizer, self.hyper)) self.record_list.append(val) class CallbacksTest(parameterized.TestCase, tf.test.TestCase): def __init__(self, *args, **kwargs): super(CallbacksTest, self).__init__(*args, **kwargs) self.batch_size = 16 self.num_steps = 20 self.data = np.random.random((self.batch_size*self.num_steps)) self.labels = np.random.random((self.batch_size*self.num_steps)) def setUp(self): super(CallbacksTest, self).setUp() self.model = tf.keras.Sequential([layers.Dense(1, input_shape=(1,))]) tf.random.set_random_seed(_SEED) def testPolynomialDecayValues(self): init_value = 0.01 final_value = 0.0002 power = 0.6 num_decay_steps = 11 num_delay_steps = 3 opt = tf.keras.optimizers.Adam(learning_rate=init_value) self.model.compile(opt, 'mse') lr_list = [] cbs = [ callbacks.PolynomialDecay(hyperparameter='learning_rate', init_value=init_value, final_value=final_value, power=power, num_decay_steps=num_decay_steps, num_delay_steps=num_delay_steps, verbose=1), HyperParamTracker('learning_rate', lr_list, HyperParamTracker.BATCH) ] self.model.fit( self.data, self.labels, batch_size=self.batch_size, callbacks=cbs) expected_list = [init_value] * num_delay_steps + [ (init_value - final_value) * (1 - min(i, num_decay_steps) / float(num_decay_steps)) ** power + final_value for i in range(self.num_steps - num_delay_steps) ] self.assertAllClose(lr_list, expected_list) def testExponentialDampingValuesWithDecayRate(self): init_value = 0.01 decay_rate = 0.3 num_decay_steps = 4 num_delay_steps = 3 opt = optimizers.Kfac( learning_rate=0.01, damping=init_value, model=self.model, loss='mse') self.model.compile(opt, 'mse') damping_list = [] cbs = [ callbacks.ExponentialDecay(hyperparameter='damping', init_value=init_value, decay_rate=decay_rate, num_decay_steps=num_decay_steps, num_delay_steps=num_delay_steps, verbose=1), HyperParamTracker('damping', damping_list, HyperParamTracker.BATCH) ] self.model.fit( self.data, self.labels, batch_size=self.batch_size, callbacks=cbs) expected_list = [init_value] * num_delay_steps + [ init_value * decay_rate ** min(i, num_decay_steps) for i in range(self.num_steps - num_delay_steps) ] self.assertAllClose(damping_list, expected_list) def testExponentialDampingValuesWithFinalValue(self): init_value = 0.01 final_value = 0.0001 num_decay_steps = 4 num_delay_steps = 3 opt = optimizers.Kfac( learning_rate=0.01, damping=init_value, model=self.model, loss='mse') self.model.compile(opt, 'mse') damping_list = [] cbs = [ callbacks.ExponentialDecay(hyperparameter='damping', init_value=init_value, final_value=final_value, 
num_decay_steps=num_decay_steps, num_delay_steps=num_delay_steps, verbose=1), HyperParamTracker('damping', damping_list, HyperParamTracker.BATCH) ] self.model.fit( self.data, self.labels, batch_size=self.batch_size, callbacks=cbs) expected_list = [init_value] * num_delay_steps + [ init_value * (final_value/init_value) ** (min(i, num_decay_steps)*1./num_decay_steps) for i in range(self.num_steps - num_delay_steps) ] self.assertAllClose(damping_list, expected_list) self.assertNear(damping_list[-1], final_value, err=1e-5) def testExponentialDampingValuesWithFinalValueAndRate(self): init_value = 0.01 final_value = 0.0001 decay_rate = 0.6 num_delay_steps = 3 opt = optimizers.Kfac( learning_rate=0.01, damping=init_value, model=self.model, loss='mse') self.model.compile(opt, 'mse') damping_list = [] cbs = [ callbacks.ExponentialDecay(hyperparameter='damping', init_value=init_value, final_value=final_value, decay_rate=decay_rate, num_delay_steps=num_delay_steps, verbose=1), HyperParamTracker('damping', damping_list, HyperParamTracker.BATCH) ] self.model.fit( self.data, self.labels, batch_size=self.batch_size, callbacks=cbs) expected_list = [init_value] * num_delay_steps + [ max((init_value * decay_rate ** i), final_value) for i in range(self.num_steps - num_delay_steps) ] self.assertAllClose(damping_list, expected_list) self.assertNear(damping_list[-1], final_value, err=1e-5) @parameterized.named_parameters( ('_Exponential', 'damping', callbacks.ExponentialDecay(hyperparameter='damping', init_value=0.01, decay_rate=0.3, num_decay_steps=30)), ('_Polynomial', 'learning_rate', callbacks.PolynomialDecay(hyperparameter='learning_rate', init_value=0.001, final_value=0.002, power=0.6, num_decay_steps=30))) def testTrainHistory(self, hyper, callback): opt = optimizers.Kfac(learning_rate=0.001, damping=0.01, model=self.model, loss='mse', num_burnin_steps=5) self.model.compile(opt, 'mse') lst = [] cbs = [callback, HyperParamTracker(hyper, lst, HyperParamTracker.EPOCH)] hist = self.model.fit(self.data, self.labels, batch_size=self.batch_size, epochs=3, callbacks=cbs) self.assertAllClose(lst, hist.history[hyper]) def testDampingDecayFailsWithNoDamping(self): with self.assertRaisesRegex(ValueError, '.*must have a "damping".*'): self.model.compile('adam', 'mse') cb = callbacks.ExponentialDecay(hyperparameter='damping', init_value=0.01, decay_rate=0.3, num_decay_steps=4) self.model.fit(self.data, self.data, callbacks=[cb]) def testExponentialDampingFailsNoRateOrFinalValue(self): with self.assertRaisesRegex(ValueError, '.*must specify exactly two of.*'): callbacks.ExponentialDecay(hyperparameter='damping', init_value=0.01) def testExponentialDampingFailsWithAllOptionals(self): with self.assertRaisesRegex(ValueError, '.*must specify exactly two of.*'): callbacks.ExponentialDecay(hyperparameter='learning_rate', init_value=0.01, final_value=0.001, decay_rate=0.99, num_decay_steps=50) if __name__ == '__main__': tf.disable_v2_behavior() tf.test.main() ================================================ FILE: kfac/python/kernel_tests/keras_optimizers_test.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for keras/optimizers.py.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import json from absl.testing import parameterized from tensorflow.python.keras import backend import numpy as np import tensorflow.compat.v1 as tf from tensorflow.python.util import serialization from kfac.python.keras import optimizers from kfac.python.keras import utils layers = tf.keras.layers losses = tf.keras.losses _SEED = 1234 # TODO(b/135916953): Use TensorFlow test_utils instead of below helpers. def _get_synthetic_mnist_dataset(train_size=64, test_size=16): num_classes = 10 img_rows, img_cols = 28, 28 rng = np.random.RandomState(_SEED) num_examples = train_size + test_size images = rng.rand(num_examples, img_rows * img_cols).astype(np.float32) images = np.reshape(images, [num_examples, img_rows, img_cols, 1]) labels = rng.randint(num_classes, size=num_examples) one_hot_labels = np.eye(num_classes)[labels].astype(np.float32) return ((images[:train_size], one_hot_labels[:train_size]), (images[train_size:], one_hot_labels[train_size:])) def _get_synthetic_mnist_train_tensors( train_size=64, batch_size=10, drop_remainder=False): (x_train, y_train), _ = _get_synthetic_mnist_dataset(train_size=train_size) dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) dataset = dataset.repeat().batch(batch_size, drop_remainder=drop_remainder) return dataset.make_one_shot_iterator().get_next() def _generate_target_fn(num_examples): """Generated a random 2d target function for regression. Args: num_examples: The number of evenly spaced examples along the function to generate. Returns: A tuple of the x tensor and the y tensor for the generated function. """ inds = np.arange(num_examples) x = np.sort(np.random.rand(num_examples) - 0.5) x = np.expand_dims(x, axis=1) y = np.transpose(x) dist = np.square(x - y) # Should be scipy cdist(x, x, metric='sqeuclidean') k = np.exp(-dist / 0.01) k += np.eye(k.shape[0]) * 1e-6 l = np.linalg.cholesky(k) random_y = np.random.randn(x.shape[0], 1) y = np.dot(l, random_y) + np.random.randn(x.shape[0], 1) * 1e-1 return x[inds, :], y[inds, :] def _generate_regression_data(num_eg, num_train_eg): x_all, y_all = _generate_target_fn(num_eg) x_all = x_all.astype(np.float32) y_all = y_all.astype(np.float32) inds = np.arange(num_eg) np.random.shuffle(inds) x_train = x_all[inds[:num_train_eg]] y_train = y_all[inds[:num_train_eg]] x_test = x_all[inds[num_train_eg:]] y_test = y_all[inds[num_train_eg:]] return (x_train, y_train), (x_test, y_test) def _simple_mlp(): return tf.keras.Sequential([ layers.Dense(32, input_shape=(1,), activation='tanh'), layers.Dense(32, activation='tanh'), layers.Dense(1) ]) def _mnist_model(use_bias=True, use_separate_activation=True): """A complex architecture to test the variable registration. This model is not intended to be a "good" mnist classifier. It uses Lambda layers, concats, and separate branches to test effectively. Args: use_bias: boolean. Whether all the layers use a bias term or not. 
use_separate_activation: boolean. Whether the layers have the activation within the layer or use a separate activation layer. Returns: A Keras model containing the mnist classifier. """ activation = 'linear' if use_separate_activation else 'relu' output_activation = 'linear' if use_separate_activation else 'softmax' inp = layers.Input(shape=(28, 28, 1)) branch1 = layers.Lambda(lambda x: tf.squeeze(x, -1))(inp) branch1 = layers.Conv1D(3, kernel_size=7, activation=activation, use_bias=use_bias)(branch1) if use_separate_activation: branch1 = layers.Activation('relu')(branch1) branch1 = layers.GlobalMaxPool1D()(branch1) branch2 = layers.Conv2D(16, kernel_size=(3, 3), activation=activation, use_bias=use_bias)(inp) if use_separate_activation: branch2 = layers.Activation('relu')(branch2) branch2 = layers.MaxPooling2D(pool_size=(4, 4))(branch2) branch2 = layers.Flatten()(branch2) branch2 = layers.Dense(20, use_bias=use_bias)(branch2) if use_separate_activation: branch2 = layers.Activation('relu')(branch2) out = layers.concatenate([branch1, branch2]) out = layers.Dense(10, use_bias=use_bias, activation=output_activation)(out) if use_separate_activation: out = layers.Activation('softmax')(out) return tf.keras.Model(inputs=inp, outputs=out) def _train_model(data, model, loss, lr=0.001, damping=0.001, batch_size=32, epochs=1, loss_weights=None): """Compiles and fits model to data and returns trainging results. Args: data: Tuple of numpy arrays shaped ((x_train, y_train), (x_test, y_test)). model: Uncompiled Keras model with inputs/output shapes matching the data. loss: tf.keras.losses loss function or serialized (string) loss function. lr: Learning rate for optimizer. damping: Damping parameter for KFAC. batch_size: Batch size used for training. epochs: Number of training epochs. loss_weights: List of weights or dict mapping layer names to loss function weight. Returns: A History object. Calling History.history gives you a dictionary with training and validation results. 
""" (x_train, y_train), valid_data = data opt = optimizers.Kfac(learning_rate=lr, damping=damping, model=model, loss=loss, loss_weights=loss_weights) model.compile(opt, loss, loss_weights=loss_weights) return model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=valid_data, verbose=0) class KfacOptimizerTest(parameterized.TestCase, tf.test.TestCase): def __init__(self, *args, **kwargs): super(KfacOptimizerTest, self).__init__(*args, **kwargs) self._mnist_data = _get_synthetic_mnist_dataset() def setUp(self): super(KfacOptimizerTest, self).setUp() tf.random.set_random_seed(_SEED) np.random.seed(_SEED) def testFunctionalInstantiation(self): inputs = layers.Input(shape=(3,)) x = layers.Dense(4, activation=tf.nn.relu)(inputs) outputs = layers.Dense(5, activation=tf.nn.softmax)(x) model = tf.keras.Model(inputs=inputs, outputs=outputs) optimizers.Kfac(learning_rate=0.002, damping=0.04, model=model, loss='binary_crossentropy') def testSequentialInstantiation(self): model = tf.keras.Sequential([ layers.Conv2D(7, (3, 3), input_shape=(28, 28, 3)), layers.Activation('relu'), layers.Conv2D(13, (3, 3), activation='relu'), layers.GlobalMaxPool2D(), layers.Activation('softmax') ]) optimizers.Kfac(learning_rate=0.03, damping=0.00007, model=model, loss='binary_crossentropy') def testInstantiationWithLayerCollection(self): model = _simple_mlp() lc = utils.get_layer_collection(model, 'mse') opt = optimizers.Kfac( learning_rate=0.1, damping=0.2, layer_collection=lc) model.compile(optimizer=opt, loss='mse') opt.get_updates(model.total_loss, model.trainable_weights) def testRNNFails(self): model = tf.keras.Sequential() model.add(layers.Embedding(43, 128)) model.add(layers.LSTM(128, dropout=0.2, recurrent_dropout=0.2)) model.add(layers.Dense(1, activation='sigmoid')) opt = optimizers.Kfac(learning_rate=0.003, damping=0.003, model=model, loss='binary_crossentropy') with self.assertRaisesRegex(ValueError, '.*lstm.* has more than one parent tensor.$'): opt._create_optimizer() @parameterized.named_parameters(('BiasCombinedActivation', True, True), ('BiasSeparateActivation', True, False), ('NoBiasCombinedActivation', False, True), ('NoBiasSeparateActivation', False, False)) def testBiasAndActivations(self, use_bias, use_separate_activation): model = _mnist_model(use_bias=use_bias, use_separate_activation=use_separate_activation) _train_model(self._mnist_data, model, 'categorical_crossentropy') def testRegression(self): hist = _train_model( _generate_regression_data(200, 150), _simple_mlp(), 'mse', epochs=5) val_loss = hist.history['val_loss'] self.assertGreater(val_loss[0], val_loss[-1]) def testClipNormFails(self): with self.assertRaises(ValueError): optimizers.Kfac(learning_rate=0.001, damping=0.001, model=_simple_mlp(), loss='mse', clipnorm=0.1) def testClipValueFails(self): with self.assertRaises(ValueError): optimizers.Kfac(learning_rate=0.01, damping=0.01, model=_simple_mlp(), loss='mse', clipvalue=0.1) def testLossTensor(self): loss_tensor = tf.convert_to_tensor(2.0) opt = optimizers.Kfac(learning_rate=0.01, damping=0.01, model=_simple_mlp(), loss='mse', loss_tensor=loss_tensor) self.assertEqual(opt.optimizer._loss_tensor, loss_tensor) def testArgsKwargs(self): """Test if kwargs are correctly forwarded to tensorflow_kfac.""" kwargs = { 'learning_rate': 3.0, 'damping': 5.0, 'momentum': 7.0, 'min_damping': 9.0, 'num_burnin_steps': 11, 'invert_every': 13, 'fisher_approx': { layers.Dense: 'kron_in_diag', 'dense_1': 'kron_both_diag' }, } model = _simple_mlp() opt = 
optimizers.Kfac(model=model, loss='mse', **kwargs) self.assertEqual(opt.optimizer._min_damping, kwargs['min_damping']) self.assertEqual(opt.optimizer._num_burnin_steps, kwargs['num_burnin_steps']) self.assertEqual(opt.optimizer._invert_every, kwargs['invert_every']) fisher_block_0 = opt.optimizer.layers.fisher_blocks[model.layers[0].weights] self.assertTrue(fisher_block_0._diagonal_approx_for_input) self.assertFalse(fisher_block_0._diagonal_approx_for_output) fisher_block_1 = opt.optimizer.layers.fisher_blocks[model.layers[1].weights] self.assertTrue(fisher_block_1._diagonal_approx_for_input) self.assertTrue(fisher_block_1._diagonal_approx_for_output) with tf.Session() as sess: # In Keras, typically you do not use sessions directly. When you use a # Keras component, the required variables are initialized for you because # they are tracked. Here, we explicitly run the variables in a session so # they must be initialized. sess.run(tf.global_variables_initializer()) self.assertEqual(sess.run(opt.optimizer.momentum), kwargs['momentum']) self.assertEqual(sess.run(opt.optimizer.learning_rate), kwargs['learning_rate']) self.assertEqual(sess.run(opt.optimizer.damping), kwargs['damping']) def testConfig(self): fisher_approx = {layers.Dense: 'kron_in_diag', 'dense_1': 'kron_both_diag'} kwargs = { 'loss': 'mse', 'momentum': 7.0, 'num_burnin_steps': 11.0, 'min_damping': 9.0, 'invert_every': 13, 'fisher_approx': fisher_approx, 'seed': 12, } opt = optimizers.Kfac( learning_rate=3.0, damping=5.0, model=_simple_mlp(), **kwargs) opt.learning_rate = 23.0 opt.damping = 27.0 config = opt.get_config() self.assertEqual(config['learning_rate'], 23.0) self.assertEqual(config['damping'], 27.0) dense_approx = fisher_approx.pop(layers.Dense) fisher_approx[utils._CLASS_NAME_PREFIX + 'Dense'] = dense_approx for key, val in kwargs.items(): self.assertEqual(config[key], val) # Below is how Keras's model.save saves the configs. If the config is not # serializable, it will throw a TypeError or OverflowError. json.dumps(config, default=serialization.get_json_type).encode('utf8') @parameterized.named_parameters(('_LossName', {'loss': 'mse'}), ('_LossFunction', {'loss': losses.MSE})) def testFromConfig(self, kwargs_updates): kwargs = { 'learning_rate': 3.0, 'damping': 5.0, 'momentum': 7.0, 'min_damping': 9.0, 'num_burnin_steps': 11, 'invert_every': 13, 'fisher_approx': { layers.Dense: 'kron_in_diag', 'dense_1': 'kron_both_diag' }, } kwargs.update(kwargs_updates) opt = optimizers.Kfac(model=_simple_mlp(), **kwargs) config = opt.get_config() config['name'] = 'diff_scope_name' opt2 = optimizers.Kfac.from_config(config) config2 = opt2.get_config() config2.pop('name') config.pop('name') self.assertEqual(config, config2) # Below is how Keras's model.save saves the configs. If the config is not # serializable, it will throw a TypeError or OverflowError. 
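# Note (added): no assertion is made on the dumped JSON itself; the calls below only
# need to complete without raising for the configs to count as serializable.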
json.dumps(config, default=serialization.get_json_type).encode('utf8') json.dumps(config2, default=serialization.get_json_type).encode('utf8') @parameterized.named_parameters(('_Tensor', tf.convert_to_tensor), ('_Float', float)) def testGettingHyper(self, hyper_ctor): kwarg_values = {'learning_rate': 3.0, 'damping': 20.0, 'momentum': 13.0} kwargs = {k: hyper_ctor(v) for k, v in kwarg_values.items()} opt = optimizers.Kfac(model=_simple_mlp(), loss='mse', **kwargs) get_value = backend.get_value tf_opt = opt.optimizer with self.subTest(name='MatchesFloat'): for name, val in kwarg_values.items(): self.assertEqual(get_value(getattr(opt, name)), val) with self.subTest(name='MatchesTfOpt'): self.assertEqual(get_value(opt.lr), get_value(tf_opt.learning_rate)) self.assertEqual(get_value(opt.damping), get_value(tf_opt.damping)) self.assertEqual(get_value(opt.momentum), get_value(tf_opt.momentum)) def testGettingVariableHyperFails(self): self.skipTest('This is not fixed in TF 1.14 yet.') opt = optimizers.Kfac(model=_simple_mlp(), loss='mse', learning_rate=tf.Variable(0.1), damping=tf.Variable(0.1)) with self.assertRaisesRegex(tf.errors.FailedPreconditionError, '.*uninitialized.*'): backend.get_value(opt.learning_rate) @parameterized.named_parameters( (('_' + name, name, float(val+1)) for val, name in enumerate(optimizers._MUTABLE_HYPER_PARAMS))) def testSetTFVariableHyper(self, name, val): kwargs = {'learning_rate': 0.01, 'damping': 0.001} kwargs[name] = tf.Variable(45.0) opt = optimizers.Kfac(model=_simple_mlp(), loss='mse', **kwargs) setattr(opt, name, val) with self.subTest(name='AssignedCorrectly'): self.assertEqual(backend.get_value(getattr(opt, name)), val) if hasattr(opt.optimizer, name): self.assertEqual(backend.get_value(getattr(opt.optimizer, name)), val) with self.subTest(name='SetError'): with self.assertRaisesRegex(ValueError, 'Dynamic reassignment only.*'): setattr(opt, name, tf.convert_to_tensor(2.0)) with self.assertRaisesRegex(ValueError, 'Dynamic reassignment only.*'): setattr(opt, name, tf.Variable(2.0)) @parameterized.named_parameters( (('_' + name, name, float(val + 1)) for val, name in enumerate(optimizers._MUTABLE_HYPER_PARAMS))) def testSetFloatHyper(self, name, val): kwargs = {'learning_rate': 0.01, 'damping': 0.001} kwargs[name] = 45.0 opt = optimizers.Kfac(model=_simple_mlp(), loss='mse', **kwargs) setattr(opt, name, val) with self.subTest(name='AssignedCorrectly'): self.assertEqual(backend.get_value(getattr(opt, name)), val) if hasattr(opt.optimizer, name): self.assertEqual(backend.get_value(getattr(opt.optimizer, name)), val) with self.subTest(name='SetError'): with self.assertRaisesRegex(ValueError, 'Dynamic reassignment only.*'): setattr(opt, name, tf.convert_to_tensor(2.0)) with self.assertRaisesRegex(ValueError, 'Dynamic reassignment only.*'): setattr(opt, name, tf.Variable(2.0)) @parameterized.named_parameters( (('_' + name, name, float(val + 1)) for val, name in enumerate(optimizers._MUTABLE_HYPER_PARAMS))) def testModifyingTensorHypersFails(self, name, val): kwargs = {'learning_rate': 3.0, 'damping': 5.0, 'momentum': 7.0} kwargs[name] = tf.convert_to_tensor(val) opt = optimizers.Kfac(model=_simple_mlp(), loss='mse', **kwargs) with self.subTest(name='AssignedCorrectly'): self.assertEqual(backend.get_value(getattr(opt, name)), val) with self.subTest(name='RaisesError'): with self.assertRaisesRegex(AttributeError, "Can't set attribute: {}".format(name)): setattr(opt, name, 17) def testLRBackwardsCompatibility(self): """This tests learning rate getting/setting used by 
old Keras callbacks.""" opt = optimizers.Kfac( learning_rate=3.0, damping=5.0, model=_simple_mlp(), loss='mse') self.assertEqual(backend.get_value(opt.lr), 3.0) self.assertEqual(backend.get_value(opt.learning_rate), 3.0) opt.lr = 7.0 self.assertEqual(backend.get_value(opt.lr), 7.0) self.assertEqual(backend.get_value(opt.learning_rate), 7.0) backend.set_value(opt.lr, 9.0) self.assertEqual(backend.get_value(opt.lr), 9.0) self.assertEqual(backend.get_value(opt.learning_rate), 9.0) backend.set_value(opt.learning_rate, 11.0) self.assertEqual(backend.get_value(opt.lr), 11.0) self.assertEqual(backend.get_value(opt.learning_rate), 11.0) def testMultipleLossTraining(self): inp = layers.Input(shape=(28, 28, 1)) branch1 = layers.Conv2D(13, 7, activation='relu')(inp) branch1 = layers.GlobalMaxPool2D()(branch1) branch1 = layers.Dense(1, name='path1')(branch1) branch2 = layers.Conv2D(16, 3, activation='relu')(inp) branch2 = layers.MaxPooling2D(pool_size=(4, 4))(branch2) branch2 = layers.Flatten()(branch2) branch2 = layers.Dense(9, name='path2')(branch2) model = tf.keras.Model(inputs=inp, outputs=[branch1, branch2]) loss = {'path1': 'binary_crossentropy', 'path2': 'categorical_crossentropy'} loss_weights = {'path1': 0.1, 'path2': 0.9} (x, y), (valid_x, valid_y) = _get_synthetic_mnist_dataset() y1, y2 = y[:, 0:1], y[:, 1:] valid_y1, valid_y2 = valid_y[:, 0:1], valid_y[:, 1:] data = (x, (y1, y2)), (valid_x, (valid_y1, valid_y2)) _train_model(data, model, loss, loss_weights=loss_weights) @parameterized.named_parameters(('_LossName', 'categorical_crossentropy'), ('_LossFunction', losses.binary_crossentropy)) def testRegisterLayersWithModel(self, loss): model = _mnist_model() opt = optimizers.Kfac(learning_rate=0.01, damping=0.001) opt.register_layers(model=model, loss=loss) model.compile(optimizer=opt, loss=loss) opt.get_updates(model.total_loss, model.trainable_weights) def testRegisterLayersWithLayerCollection(self): model, loss = _mnist_model(), 'categorical_crossentropy' lc = utils.get_layer_collection(model, loss) opt = optimizers.Kfac(learning_rate=0.01, damping=0.001) opt.register_layers(layer_collection=lc) model.compile(optimizer=opt, loss=loss) opt.get_updates(model.total_loss, model.trainable_weights) @parameterized.named_parameters(('_LossName', 'categorical_crossentropy'), ('_LossFunction', losses.binary_crossentropy)) def testRegisterLayersCompiledModel(self, loss): opt = optimizers.Kfac(learning_rate=0.01, damping=0.001) model = _mnist_model() model.compile(optimizer=opt, loss=loss) opt.register_layers(model=model) model.compile(optimizer=opt, loss=loss) opt.get_updates(model.total_loss, model.trainable_weights) def testTrainWithoutCreatingOptimizerFails(self): with self.assertRaisesRegex(ValueError, '.*provide a model with a loss.*'): opt = optimizers.Kfac(learning_rate=0.01, damping=0.001) model = _mnist_model() model.compile(optimizer=opt, loss='categorical_crossentropy') grads_vars = opt.get_gradients(model.total_loss, model.trainable_weights) opt.apply_gradients(grads_vars) def testEmptyCreateKfacOptimizerFails(self): with self.assertRaisesRegex(ValueError, '.*provide a model with a loss.*'): opt = optimizers.Kfac(learning_rate=0.01, damping=0.001) opt._create_optimizer() def testSeed(self): opt = optimizers.Kfac(learning_rate=0.01, damping=0.01, model=_simple_mlp(), loss='mse', seed=4321) lc = opt.optimizer.layers self.assertEqual(lc._loss_dict['squared_error_loss'][0]._default_seed, 4321) def testNewOptSameVarScope(self): model = _simple_mlp() opt = optimizers.Kfac( learning_rate=0.01, 
damping=0.01, model=model, loss='mse') opt._create_optimizer() opt2 = optimizers.Kfac( learning_rate=0.02, damping=0.03, model=model, loss='mse') opt2._create_optimizer() def testGetSetWeights(self): def model_maker(): return tf.keras.Sequential([layers.Dense(2, input_shape=(3,))]) x = np.random.random((1, 3)) y = np.random.random((1, 2)) loss = 'mse' model = model_maker() opt = optimizers.Kfac(learning_rate=0.01, damping=0.1, model=model, loss=loss, seed=1234) model.compile(optimizer=opt, loss=loss) model.train_on_batch(x, y) opt_weights = opt.get_weights() self.assertEqual(1, opt_weights[0]) # iterations self.assertEqual(1, opt_weights[6]) # counter self.assertEqual(0, opt_weights[7]) # burn in counter config = opt.get_config() config['name'] = 'diff_name' opt2 = optimizers.Kfac.from_config(config) model2 = model_maker() model2.compile(optimizer=opt2, loss=loss) opt2.register_layers(model=model2) # Set weights should only work after a call to get_updates/apply_gradients. x = np.random.random((1, 3)) y = np.random.random((1, 2)) model2.train_on_batch(x, y) opt2.set_weights(opt_weights) for w1, w2 in zip(opt_weights, opt2.get_weights()): self.assertAllClose(w1, w2) model2.set_weights(model.get_weights()) x = np.random.random((1, 3)) y = np.random.random((1, 2)) model.train_on_batch(x, y) model2.train_on_batch(x, y) for w1, w2 in zip(opt.get_weights(), opt2.get_weights()): self.assertAllClose(w1, w2) @parameterized.named_parameters(('_HasShift', True), ('_NoShift', False)) def testTrainModelWithNormalization(self, has_shift): model = tf.keras.Sequential([ layers.Conv2D(13, 5, input_shape=(28, 28, 1)), layers.BatchNormalization(center=has_shift, fused=False), layers.Conv2D(23, 3), layers.LayerNormalization(center=has_shift), layers.GlobalMaxPool2D(), layers.Dense(10, activation='softmax') ]) (x_train, y_train), _ = _get_synthetic_mnist_dataset() approx = {layers.LayerNormalization: 'full'} loss = 'categorical_crossentropy' opt = optimizers.Kfac(learning_rate=0.01, damping=0.01, model=model, loss=loss, fisher_approx=approx) model.compile(opt, loss) return model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=0) @parameterized.named_parameters(('_HasShift', True), ('_NoShift', False)) def testTrainModelWithFusedBN(self, has_shift): model = tf.keras.Sequential([ layers.Conv2D(13, 5, input_shape=(28, 28, 1)), layers.BatchNormalization(center=has_shift, fused=True), layers.GlobalMaxPool2D(), layers.Dense(10, activation='softmax') ]) (x_train, y_train), _ = _get_synthetic_mnist_dataset() loss = 'categorical_crossentropy' opt = optimizers.Kfac( learning_rate=0.01, damping=0.01, model=model, loss=loss) model.compile(opt, loss) return model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=0) @parameterized.named_parameters(('_HasShift', True), ('_NoShift', False)) def testTrainModelWithFusedBNAndLearningPhase(self, has_shift): tf.keras.backend.set_learning_phase(1) model = tf.keras.Sequential([ layers.Conv2D(13, 5, input_shape=(28, 28, 1)), layers.BatchNormalization(center=has_shift, fused=True), layers.GlobalMaxPool2D(), layers.Dense(10, activation='softmax') ]) (x_train, y_train), _ = _get_synthetic_mnist_dataset() loss = 'categorical_crossentropy' opt = optimizers.Kfac( learning_rate=0.01, damping=0.01, model=model, loss=loss) model.compile(opt, loss) return model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=0) @parameterized.named_parameters(('_WithShape', {'input_shape': (28, 28, 1)}), ('_WithoutShape', {})) def testCustomTrainingLoopSequential(self, input_conv_kwargs): # 
Without the input_shape the only inbound node is the correct one, with the # input_shape there are two, and we want the second one. model = tf.keras.Sequential([ layers.Conv2D(13, 5, **input_conv_kwargs), layers.BatchNormalization(fused=False), layers.Conv2D(23, 3), layers.LayerNormalization(), layers.GlobalMaxPool2D(), layers.Dense(10, activation='softmax', name='output_test') ]) x, y = _get_synthetic_mnist_train_tensors(batch_size=10) model_input = tf.keras.Input(tensor=x) output = model(model_input) loss = tf.keras.losses.binary_crossentropy(output, y) optimizer = optimizers.Kfac(learning_rate=0.01, damping=0.01, model=model, loss='binary_crossentropy') train_op = optimizer.minimize(loss, var_list=model.trainable_weights) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for _ in range(3): sess.run([train_op]) def testCustomTrainingLoopFunctionalInpTensor(self): # This case should work trivially--the only inbound node is the correct one. x, y = _get_synthetic_mnist_train_tensors(batch_size=10) # Build Model inp = tf.keras.Input(tensor=x) x = layers.Conv2D(13, 5)(inp) x = layers.BatchNormalization(fused=False)(x) x = layers.Conv2D(23, 3)(x) x = layers.LayerNormalization()(x) x = layers.GlobalMaxPool2D()(x) out = layers.Dense(10, activation='softmax', name='output_test')(x) model = tf.keras.Model(inputs=inp, outputs=out) loss = tf.keras.losses.binary_crossentropy(model.output, y) optimizer = optimizers.Kfac(learning_rate=0.01, damping=0.01, model=model, loss='binary_crossentropy') train_op = optimizer.minimize(loss, var_list=model.trainable_weights) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for _ in range(3): sess.run([train_op]) def testCustomTrainingLoopFunctionalInpShape(self): # We need to ensure correct inbound node is used for layer collection. x, y = _get_synthetic_mnist_train_tensors(batch_size=10) model_input = tf.keras.Input(tensor=x) # Build Model inp = tf.keras.Input(shape=(28, 28, 1)) x = layers.Conv2D(13, 5)(inp) x = layers.BatchNormalization(fused=True)(x) x = layers.Activation('relu')(x) x = layers.Conv2D(23, 3)(x) x = layers.LayerNormalization()(x) x = layers.GlobalMaxPool2D()(x) out = layers.Dense(10, activation='softmax', name='output_test')(x) model = tf.keras.Model(inputs=inp, outputs=out) output = model(model_input) loss = tf.keras.losses.binary_crossentropy(output, y) optimizer = optimizers.Kfac(damping=0.01, learning_rate=0.01, model=model, loss='binary_crossentropy') train_op = optimizer.minimize(loss, var_list=model.trainable_weights) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for _ in range(3): sess.run([train_op]) def testCustomTrainingLoopMakeOptimizerBeforeModelCall(self): # We defer the creation of the layer_collection to the minimize call for # this situation, because if we make the layer_collection immediately it # will capture the wrong inbound node. 
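# Note (added): at optimizer-construction time the Sequential model below has not yet
# been called on a real input tensor; the inbound node carrying the training data only
# exists after model(model_input) runs, so registration happens inside minimize().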
model = tf.keras.Sequential([ layers.Conv2D(13, 5), layers.BatchNormalization(fused=False), layers.Conv2D(23, 3), layers.LayerNormalization(), layers.GlobalMaxPool2D(), layers.Dense(10, activation='softmax', name='output_test') ]) optimizer = optimizers.Kfac(learning_rate=0.01, damping=0.01, model=model, loss='binary_crossentropy') x, y = _get_synthetic_mnist_train_tensors(batch_size=10) model_input = tf.keras.Input(tensor=x) output = model(model_input) loss = tf.keras.losses.binary_crossentropy(output, y) train_op = optimizer.minimize(loss, var_list=model.trainable_weights) with self.cached_session() as sess: sess.run(tf.global_variables_initializer()) for _ in range(3): sess.run([train_op]) def testCustomTrainingUnwrappedTensorFails(self): # This test does not test our implementation, but is here so if Keras ever # adds functionality to support raw tensors as Nodes, this test will fail # and we can remove the restriction from our documentation. model = _simple_mlp() dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(10) x, y = dataset.make_one_shot_iterator().get_next() pred = model(x) loss = tf.keras.losses.binary_crossentropy(pred, y) optimizer = optimizers.Kfac(learning_rate=0.01, damping=0.01, model=model, loss='binary_crossentropy') train_op = optimizer.minimize(loss, var_list=model.trainable_weights) with self.cached_session() as sess: sess.run(tf.global_variables_initializer()) with self.assertRaisesRegex(tf.errors.InvalidArgumentError, '.*You must feed a value for placeholder.*'): sess.run([train_op]) def testTrainingNestedModel(self): inputs = tf.keras.Input(shape=(1,)) y1 = _simple_mlp()(inputs) y2 = _simple_mlp()(inputs) y3 = _simple_mlp()(inputs) outputs = layers.average([y1, y2, y3]) ensemble_model = tf.keras.Model(inputs=inputs, outputs=outputs) optimizer = optimizers.Kfac(learning_rate=0.01, damping=0.01, model=ensemble_model, loss='binary_crossentropy') ensemble_model.compile(optimizer, 'binary_crossentropy') dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(10) x, y = dataset.make_one_shot_iterator().get_next() ensemble_model.train_on_batch(x, y) def testCustomTrainLoopNestedModel(self): inputs = tf.keras.Input(shape=(1,)) y1 = _simple_mlp()(inputs) y2 = _simple_mlp()(inputs) y3 = _simple_mlp()(inputs) outputs = layers.average([y1, y2, y3]) ensemble_model = tf.keras.Model(inputs=inputs, outputs=outputs) dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(10) x, y = dataset.make_one_shot_iterator().get_next() x = layers.Input(tensor=x) optimizer = optimizers.Kfac(learning_rate=0.01, damping=0.01, model=ensemble_model, loss='binary_crossentropy') pred = ensemble_model(x) loss = tf.keras.losses.binary_crossentropy(pred, y) train_op = optimizer.minimize( loss, var_list=ensemble_model.trainable_weights) with self.cached_session() as sess: sess.run(tf.global_variables_initializer()) sess.run([train_op]) @parameterized.named_parameters( ('_NoKwargs', {'norm_constraint'}, {}), ('_MomentumNormKwargs', set(), {'momentum': 1, 'norm_constraint': 2}), ('_QModel', {'momentum', 'learning_rate', 'norm_constraint'}, {'momentum': None, 'momentum_type': 'qmodel', 'learning_rate': None}), ('_AdaptiveDamping', {'damping', 'norm_constraint'}, {'adapt_damping': True, 'damping_adaptation_interval': 20})) def testMutableHypers(self, not_mutable, kwargs_update): kwargs = {'learning_rate': 0.01, 'damping': 0.001} kwargs.update(kwargs_update) opt = optimizers.Kfac(model=_simple_mlp(), loss='mse', **kwargs) mutable = optimizers._MUTABLE_HYPER_PARAMS 
- not_mutable self.assertEqual(set(opt.mutable_hyperparameters), mutable) def testPositionalArgsFail(self): with self.assertRaisesRegex(ValueError, 'Do not pass positional arguments.*'): optimizers.Kfac(0.1, 0.1, model=_simple_mlp(), loss='mse') def testSettingName(self): model = _simple_mlp() optimizer = optimizers.Kfac(damping=0.01, learning_rate=0.01, model=model, loss='mse') optimizer.name = 'new_name' self.assertEqual(optimizer._name, 'new_name') self.assertEqual(optimizer.get_config()['name'], 'new_name') self.assertEqual(optimizer._kfac_kwargs['name'], 'new_name') model.compile(optimizer, 'mse') model._make_train_function() with self.assertRaisesRegex(ValueError, '.*after the variables are created.*'): optimizer.name = 'another_name' @parameterized.named_parameters( ('_AdaptDamping', {'adapt_damping': True, 'learning_rate': 0.1}), ('_Adaptive', {'adaptive': True, 'qmodel_update_rescale': 0.01})) def testAdaptiveModelFit(self, adaptive_kwargs): rands = lambda: np.random.random((100, 1)).astype(np.float32) dataset = tf.data.Dataset.from_tensor_slices((rands(), rands())) dataset = dataset.repeat().batch(10, drop_remainder=True) train_batch = dataset.make_one_shot_iterator().get_next() model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))]) loss = 'mse' train_batch = dataset.make_one_shot_iterator().get_next() optimizer = optimizers.Kfac(damping=10., train_batch=train_batch, model=model, loss=loss, **adaptive_kwargs) model.compile(optimizer, loss) model.fit(train_batch, steps_per_epoch=10, epochs=1) @parameterized.named_parameters(('_Fused', True), ('_NotFused', False)) def testAdaptiveModelFitBatchnorm(self, is_fused): train_batch = _get_synthetic_mnist_train_tensors(drop_remainder=True) model = tf.keras.Sequential([ layers.Conv2D(13, 5, input_shape=(28,28,1)), layers.BatchNormalization(fused=is_fused), layers.Conv2D(23, 3), layers.LayerNormalization(), layers.GlobalMaxPool2D(), layers.Dense(10, activation='softmax', name='output_test') ]) loss = 'categorical_crossentropy' optimizer = optimizers.Kfac(damping=10., adaptive=True, train_batch=train_batch, model=model, loss=loss) model.compile(optimizer, loss) model.train_on_batch(x=train_batch[0], y=train_batch[1]) def testInferredBatchSize(self): dataset = tf.data.Dataset.from_tensors(([1.], [1.])) dataset = dataset.repeat().batch(11, drop_remainder=True) train_batch = dataset.make_one_shot_iterator().get_next() model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))]) loss = 'mse' train_batch = dataset.make_one_shot_iterator().get_next() optimizer = optimizers.Kfac(damping=10., train_batch=train_batch, model=model, adaptive=True, loss=loss, qmodel_update_rescale=0.01) model.compile(optimizer, loss) model.train_on_batch(train_batch[0], train_batch[1]) self.assertEqual( tf.keras.backend.get_value(optimizer.optimizer._batch_size), 11) @parameterized.named_parameters(('_Adaptive', {'adaptive': True}), ('_AdaptDamping', {'adapt_damping': True})) def testInferredBatchSizeFail(self, kfac_kwargs): dataset = tf.data.Dataset.from_tensors(([1.], [1.])) dataset = dataset.repeat().batch(11, drop_remainder=False) train_batch = dataset.make_one_shot_iterator().get_next() with self.assertRaisesRegex(ValueError, 'Could not infer batch_size.*'): optimizer = optimizers.Kfac(damping=10., train_batch=train_batch, **kfac_kwargs) def testOverrideAdaptiveDefaults(self): dataset = tf.data.Dataset.from_tensors(([1.], [1.])) dataset = dataset.repeat().batch(11, drop_remainder=False) train_batch = 
dataset.make_one_shot_iterator().get_next() model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))]) loss = 'mse' train_batch = dataset.make_one_shot_iterator().get_next() optimizer = optimizers.Kfac(damping=10., adaptive=True, train_batch=train_batch, model=model, batch_size=11, invert_every=1, damping_adaptation_interval=2, loss=loss, qmodel_update_rescale=0.01) model.compile(optimizer, loss) model.train_on_batch(train_batch[0], train_batch[1]) self.assertEqual(optimizer.optimizer._invert_every, 1) self.assertEqual(optimizer.optimizer._damping_adaptation_interval, 2) @parameterized.named_parameters(('_Adaptive', {'adaptive': True}), ('_Qmodel', {'momentum_type': 'qmodel'})) def testAdaptiveWithLR(self, kfac_kwargs): dataset = tf.data.Dataset.from_tensors(([1.], [1.])) dataset = dataset.repeat().batch(11, drop_remainder=True) train_batch = dataset.make_one_shot_iterator().get_next() with self.assertRaisesRegex(ValueError, 'learning_rate must be None.*'): optimizer = optimizers.Kfac(damping=10., train_batch=train_batch, learning_rate=0.1, **kfac_kwargs) def testCustomLossFn(self): rands = lambda: np.random.random((100, 1)).astype(np.float32) dataset = tf.data.Dataset.from_tensor_slices((rands(), rands())) dataset = dataset.repeat().batch(10, drop_remainder=True) train_batch = dataset.make_one_shot_iterator().get_next() model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))]) def loss_fn(inputs): mse = tf.keras.losses.mean_squared_error(model(inputs[0]), inputs[1]) return tf.reduce_mean(mse) loss = 'mse' train_batch = dataset.make_one_shot_iterator().get_next() optimizer = optimizers.Kfac(damping=10., train_batch=train_batch, adaptive=True, model=model, loss=loss, loss_fn=loss_fn, qmodel_update_rescale=0.01) model.compile(optimizer, loss) model.fit(train_batch, steps_per_epoch=10, epochs=1) self.assertEqual(loss_fn, optimizer.optimizer._loss_fn) def testRegisterTrainBatch(self): model = tf.keras.Sequential([ layers.Conv2D(13, 5, input_shape=(28,28,1)), layers.BatchNormalization(), layers.Conv2D(23, 3), layers.LayerNormalization(), layers.GlobalMaxPool2D(), layers.Dense(10, activation='softmax', name='output_test') ]) loss = 'categorical_crossentropy' optimizer = optimizers.Kfac(damping=10., adaptive=True, model=model, loss=loss) model.compile(optimizer, loss) train_batch = _get_synthetic_mnist_train_tensors(drop_remainder=True) optimizer.register_train_batch(train_batch) model.train_on_batch(x=train_batch[0], y=train_batch[1]) if __name__ == '__main__': tf.disable_v2_behavior() tf.test.main() ================================================ FILE: kfac/python/kernel_tests/keras_saving_utils_test.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #,============================================================================ """Tests for keras/saving_utils.py. These tests were forked from the hdf5_format_test.py tests in Keras. 
""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import tempfile import numpy as np import tensorflow.compat.v1 as tf from tensorflow.python.framework import test_util from kfac.python.keras import optimizers from kfac.python.keras import saving_utils try: import h5py # pylint:disable=g-import-not-at-top except ImportError: h5py = None keras = tf.keras _KFAC_KWARGS = { 'learning_rate': 0.0001, 'damping': 0.01, 'momentum': 0.85, 'fisher_approx': { keras.layers.Dense: 'kron_in_diag', }, 'loss': 'mse', # This seed is necessary to keep the optimizer updates deterministic, since # we're approximating the true Fisher by sampling the targets. Since for # many tests we only do one training step, the approximations can vary # significantly without a set seed. 'seed': 1234, } class SavingUtilsTest(tf.test.TestCase): @test_util.run_v1_only('b/120994067') def test_sequential_model_saving(self): if h5py is None: self.skipTest('h5py required to run this test') with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(2,))) model.add(keras.layers.RepeatVector(3)) model.add(keras.layers.Flatten()) model.add(keras.layers.Dense(3)) model.compile( loss=keras.losses.MSE, optimizer=optimizers.Kfac(model=model, **_KFAC_KWARGS), metrics=[ keras.metrics.categorical_accuracy, keras.metrics.CategoricalAccuracy() ]) x = np.random.random((1, 2)) y = np.random.random((1, 3)) # TODO(b/136561651): Since we use TFP distributions to sample from the # output distribution, optimizer's won't match exactly unless they are run # for the same number of steps. Even with a random seed, the internal # state of TFP changes with each call. We must switch to a stateless # sampler. Uncomment the train line below once this is implemented. 
# model.train_on_batch(x, y) out = model.predict(x) fd, fname = tempfile.mkstemp('.h5') keras.models.save_model(model, fname) new_model = saving_utils.load_model(fname, optimizer_name='new') os.close(fd) os.remove(fname) out2 = new_model.predict(x) self.assertAllClose(out, out2, atol=1e-05) # test that new updates are the same with both models x = np.random.random((1, 2)) y = np.random.random((1, 3)) model.train_on_batch(x, y) new_model.train_on_batch(x, y) x = np.random.random((1, 2)) y = np.random.random((1, 3)) eval_out = model.evaluate(x, y) eval_out2 = new_model.evaluate(x, y) self.assertArrayNear(eval_out, eval_out2, 1e-03) out = model.predict(x) out2 = new_model.predict(x) self.assertAllClose(out, out2, atol=1e-05) @test_util.run_deprecated_v1 def test_functional_model_saving(self): if h5py is None: self.skipTest('h5py required to run this test') with self.cached_session(): inputs = keras.layers.Input(shape=(3,)) x = keras.layers.Dense(2)(inputs) output = keras.layers.Dense(3)(x) model = keras.models.Model(inputs, output) model.compile( loss=keras.losses.MSE, optimizer=optimizers.Kfac(model=model, **_KFAC_KWARGS), metrics=[ keras.metrics.categorical_accuracy, keras.metrics.CategoricalAccuracy() ], weighted_metrics=[ keras.metrics.categorical_accuracy, keras.metrics.CategoricalAccuracy() ]) x = np.random.random((1, 3)) y = np.random.random((1, 3)) model.train_on_batch(x, y) out = model.predict(x) fd, fname = tempfile.mkstemp('.h5') keras.models.save_model(model, fname) model = saving_utils.load_model(fname, optimizer_name='new') os.close(fd) os.remove(fname) out2 = model.predict(x) self.assertAllClose(out, out2, atol=1e-05) def test_saving_model_with_long_layer_names(self): if h5py is None: self.skipTest('h5py required to run this test') with self.cached_session(): # This layer name will make the `layers_name` HDF5 attribute blow # out of proportion. Note that it fits into the internal HDF5 # attribute memory limit on its own but because h5py converts # the list of layer names into numpy array, which uses the same # amount of memory for every item, it increases the memory # requirements substantially. x = keras.Input(shape=(2,), name='input_' + ('x' * (2**15))) f = x for i in range(4): f = keras.layers.Dense(2, name='dense_%d' % (i,))(f) model = keras.Model(inputs=[x], outputs=[f]) model.compile(optimizers.Kfac(model=model, **_KFAC_KWARGS), loss=keras.losses.MeanSquaredError(), metrics=['acc']) x = np.random.random((1, 2)) y = np.random.random((1, 2)) model.train_on_batch(x, y) out = model.predict(x) fd, fname = tempfile.mkstemp('.h5') keras.models.save_model(model, fname) model = saving_utils.load_model(fname, optimizer_name='new') # Check that the HDF5 files contains chunked array # of layer names. with h5py.File(fname, 'r') as h5file: num_names_arrays = len([attr for attr in h5file['model_weights'].attrs if attr.startswith('layer_names')]) # The chunking of layer names array should have happened. self.assertGreater(num_names_arrays, 0) out2 = model.predict(x) self.assertAllClose(out, out2, atol=1e-05) # Cleanup os.close(fd) os.remove(fname) def test_saving_model_with_long_weights_names(self): self.skipTest('KFAC does not support nested models yet.') if h5py is None: self.skipTest('h5py required to run this test') with self.cached_session(): x = keras.Input(shape=(2,), name='nested_model_input') f = x for i in range(4): f = keras.layers.Dense(2, name='nested_model_dense_%d' % (i,))(f) # This layer name will make the `weights_name` # HDF5 attribute blow out of proportion. 
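# Note (added): as with the long layer names above, h5py stores weight names in a
# fixed-width numpy array, so a single very long name inflates the attribute and forces
# the chunked-attribute code path that is checked below.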
f = keras.layers.Dense(2, name='nested_model_output' + ('x' * (2**14)))(f) nested_model = keras.Model(inputs=[x], outputs=[f], name='nested_model') x = keras.Input(shape=(2,), name='outer_model_input') f = nested_model(x) f = keras.layers.Dense(2, name='outer_model_output')(f) model = keras.Model(inputs=[x], outputs=[f]) model.compile(loss='mse', optimizer=optimizers.Kfac(model=model, **_KFAC_KWARGS), metrics=['acc']) x = np.random.random((1, 2)) y = np.random.random((1, 2)) model.train_on_batch(x, y) out = model.predict(x) fd, fname = tempfile.mkstemp('.h5') keras.models.save_model(model, fname) model = saving_utils.load_model(fname, optimizer_name='new') # Check that the HDF5 files contains chunked array # of weight names. with h5py.File(fname, 'r') as h5file: num_weight_arrays = len( [attr for attr in h5file['model_weights']['nested_model'].attrs if attr.startswith('weight_names')]) # The chunking of layer names array should have happened. self.assertGreater(num_weight_arrays, 0) out2 = model.predict(x) self.assertAllClose(out, out2, atol=1e-05) # Cleanup os.close(fd) os.remove(fname) @test_util.run_deprecated_v1 def test_model_saving_to_pre_created_h5py_file(self): if h5py is None: self.skipTest('h5py required to run this test') with self.cached_session(): inputs = keras.Input(shape=(3,)) x = keras.layers.Dense(2)(inputs) outputs = keras.layers.Dense(3)(x) model = keras.Model(inputs, outputs) model.compile( loss=keras.losses.MSE, optimizer=optimizers.Kfac(model=model, **_KFAC_KWARGS), metrics=[ keras.metrics.categorical_accuracy, keras.metrics.CategoricalAccuracy() ]) x = np.random.random((1, 3)) y = np.random.random((1, 3)) model.train_on_batch(x, y) out = model.predict(x) fd, fname = tempfile.mkstemp('.h5') with h5py.File(fname, mode='r+') as h5file: keras.models.save_model(model, h5file) loaded_model = saving_utils.load_model(h5file, optimizer_name='new') out2 = loaded_model.predict(x) self.assertAllClose(out, out2, atol=1e-05) # Test non-default options in h5 with h5py.File( '-', driver='core', mode='w', backing_store=False) as h5file: keras.models.save_model(model, h5file) loaded_model = saving_utils.load_model(h5file, optimizer_name='new2') out2 = loaded_model.predict(x) self.assertAllClose(out, out2, atol=1e-05) # Cleanup os.close(fd) os.remove(fname) def test_saving_constant_initializer_with_numpy(self): if h5py is None: self.skipTest('h5py required to run this test') with self.cached_session(): model = keras.models.Sequential() model.add( keras.layers.Dense( 2, input_shape=(3,), kernel_initializer=keras.initializers.Constant(np.ones((3, 2))))) model.add(keras.layers.Dense(3)) model.compile(loss='mse', optimizer=optimizers.Kfac(model=model, **_KFAC_KWARGS), metrics=['acc']) fd, fname = tempfile.mkstemp('.h5') keras.models.save_model(model, fname) model = saving_utils.load_model(fname, optimizer_name='new') os.close(fd) os.remove(fname) if __name__ == '__main__': tf.disable_v2_behavior() tf.test.main() ================================================ FILE: kfac/python/kernel_tests/keras_utils_test.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for keras/utils.py.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized import numpy as np import tensorflow.compat.v1 as tf from kfac.python.keras import utils from kfac.python.ops import fisher_blocks from kfac.python.ops import loss_functions layers = tf.keras.layers losses = tf.keras.losses _SEED = 1234 def _mlp(): return tf.keras.Sequential([ layers.Embedding(100, 13, input_length=1), layers.Flatten(), layers.Dense(32, activation='tanh'), layers.Dense(32, activation='tanh'), layers.Dense(1) ]) def _cnn(): return tf.keras.Sequential([ layers.Conv2D(7, 5, input_shape=(28, 28, 3)), layers.Activation('relu'), layers.Conv2D(13, (3, 3), activation='relu'), layers.GlobalMaxPool2D(), layers.Activation('softmax') ]) def _two_loss_model(num_branch1_outputs=1, num_branch2_outputs=9): inp = layers.Input(shape=(28, 28, 1)) branch1 = layers.Lambda(lambda x: tf.squeeze(x, -1))(inp) branch1 = layers.Conv1D(13, 7, activation='relu')(branch1) branch1 = layers.GlobalMaxPool1D()(branch1) branch1 = layers.Dense(num_branch1_outputs, name='out1')(branch1) branch2 = layers.Conv2D(16, 3, activation='relu')(inp) branch2 = layers.MaxPooling2D(pool_size=(4, 4))(branch2) branch2 = layers.Flatten()(branch2) branch2 = layers.Dense(num_branch2_outputs, name='out2')(branch2) return inp, (branch1, branch2) class GetLayerCollectionTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): super(GetLayerCollectionTest, self).setUp() tf.reset_default_graph() tf.random.set_random_seed(_SEED) @parameterized.named_parameters( ('_Categorical', 'categorical_crossentropy', loss_functions.CategoricalLogitsNegativeLogProbLoss), ('_Binary', 'binary_crossentropy', loss_functions.MultiBernoulliNegativeLogProbLoss), ('_Sparse', losses.sparse_categorical_crossentropy, loss_functions.CategoricalLogitsNegativeLogProbLoss)) def testValidLogitLossFunctionsCNN(self, loss, kfac_loss): """Ensures correct tensorflow_kfac loss function and variable for a CNN. Args: loss: A losses function (in serialized form or actual reference) kfac_loss: tensorflow_kfac.python.ops loss function. """ with tf.Graph().as_default(): model = _cnn() lc = utils.get_layer_collection(model, loss) self.assertIsInstance(lc.losses[0], kfac_loss) self.assertEqual(lc.losses[0].params, utils.get_parent(model.layers[-1].output)) @parameterized.named_parameters( ('_Categorical', 'categorical_crossentropy', loss_functions.CategoricalLogitsNegativeLogProbLoss), ('_Binary', 'binary_crossentropy', loss_functions.MultiBernoulliNegativeLogProbLoss), ('_Sparse', losses.sparse_categorical_crossentropy, loss_functions.CategoricalLogitsNegativeLogProbLoss)) def testValidLogitLossFunctionsMLP(self, loss, kfac_loss): """Ensures correct tensorflow_kfac loss function and variable for a MLP. Args: loss: A losses function (in serialized form or actual reference) kfac_loss: tensorflow_kfac.python.ops loss function. 
""" with tf.Graph().as_default(): model = _mlp() lc = utils.get_layer_collection(model, loss) self.assertIsInstance(lc.losses[0], kfac_loss) self.assertEqual(lc.losses[0].params, model.layers[-1].output) @parameterized.named_parameters(('_LongCNN', 'mean_squared_error', _cnn), ('ShortCNN', 'mse', _cnn), ('_LongMLP', losses.mean_squared_error, _mlp), ('ShortMLP', 'mse', _mlp), ('_Class', losses.MeanSquaredError(), _mlp)) def testValidMSE(self, loss, model_builder): """Ensures variations of MSE and output variables work. Args: loss: A tf.keras.losses function (in serialized form or actual reference) model_builder: Function that returns a Keras model. """ model = model_builder() lc = utils.get_layer_collection(model, loss) self.assertIsInstance(lc.losses[0], loss_functions.NormalMeanNegativeLogProbLoss) self.assertEqual(lc.losses[0].params, model.layers[-1].output) @parameterized.named_parameters(('_NotRealLoss', 'blah blah blah'), ('_RealButInvalid', 'cosine'), ('_SimilarName', 'msle')) def testInvalidLossFunctions(self, loss): with self.assertRaisesRegex(ValueError, '.*loss function:.*'): model = _mlp() utils.get_layer_collection(model, loss) @parameterized.named_parameters(('_CNN', _cnn), ('_MLP', _mlp)) def testLayerRegistration(self, model_builder): model = model_builder() model.layers[0].trainable = False lc = utils.get_layer_collection(model, 'mse') registered = set(lc.registered_variables) variables = set() for layer in model.layers[1:]: if layer.trainable and layer.count_params(): variables |= set(layer.weights) self.assertEqual(registered, variables) @parameterized.named_parameters( ('_DictLoss', {'out1': 'binary_crossentropy', 'out2': 'categorical_crossentropy'}, {'out1': 0.1, 'out2': 0.9}), ('_ListLoss', ['binary_crossentropy', 'categorical_crossentropy'], [0.1, 0.9])) def testMultipleLoss(self, loss, loss_weights): inputs, (out1, out2) = _two_loss_model() model = tf.keras.Model(inputs=inputs, outputs=[out1, out2]) lc = utils.get_layer_collection(model, loss, loss_weights=loss_weights) self.assertLen(lc.loss_coeffs.keys(), 2) self.assertLen(lc.loss_colocation_ops.keys(), 2) l1 = lc._loss_dict['sigmoid_cross_entropy_loss'] l2 = lc._loss_dict['sparse_softmax_cross_entropy_loss'] self.assertLen(l1, 1) self.assertLen(l2, 1) l1, l2 = l1[0], l2[0] self.assertIsInstance(l1, loss_functions.MultiBernoulliNegativeLogProbLoss) self.assertIsInstance(l2, loss_functions.CategoricalLogitsNegativeLogProbLoss) self.assertEqual(lc.loss_coeffs[l1], 0.1) self.assertEqual(lc.loss_coeffs[l2], 0.9) self.assertEqual(lc.loss_colocation_ops[l1], out1) self.assertEqual(lc.loss_colocation_ops[l2], out2) self.assertEqual(lc.loss_coeffs[l1], 0.1) self.assertEqual(lc.loss_coeffs[l2], 0.9) @parameterized.named_parameters(('_EmptyDict', {}), ('_PartialDict', {'out2': 0.3})) def testMultipleLossWeights(self, loss_weights): inputs, (out1, out2) = _two_loss_model() model = tf.keras.Model(inputs=inputs, outputs=[out1, out2]) loss = ['binary_crossentropy', 'categorical_crossentropy'] lc = utils.get_layer_collection(model, loss, loss_weights=loss_weights) l1 = lc._loss_dict['sigmoid_cross_entropy_loss'][0] self.assertEqual(lc.loss_coeffs[l1], 1.0) @parameterized.named_parameters( ('_MissingDict', {'out2': 'categorical_crossentropy'}), ('_MissingList', ['categorical_crossentropy']), ('_ExtraDict', {'out1': 'binary_crossentropy', 'out2': 'categorical_crossentropy', 'blah': 'mse'}), ('_ExtraList', ['mse', 'binary_crossentropy', 'categorical_crossentropy']), ('_WrongName', {'out1': 'binary_crossentropy', 'path2': 
'categorical_crossentropy'})) def testLossErrors(self, loss): with self.assertRaisesRegex(ValueError, '.*loss dict.*'): inputs, (out1, out2) = _two_loss_model() model = tf.keras.Model(inputs=inputs, outputs=[out1, out2]) utils.get_layer_collection(model, loss) @parameterized.named_parameters( ('_EmptyList', []), ('_MissingList', [0.1]), ('_ExtraList', [0.1, 0.9, 0.3]), ('_ExtraDict', {'out1': 0.1, 'out2': 0.9, 'blahblah': 0.4}), ('_Set', {0.1, 0.3})) def testLossWeightErrors(self, loss_weights): with self.assertRaisesRegex(ValueError, '.*loss_weights.*'): inputs, (out1, out2) = _two_loss_model() model = tf.keras.Model(inputs=inputs, outputs=[out1, out2]) loss = ['binary_crossentropy', 'categorical_crossentropy'] utils.get_layer_collection(model, loss, loss_weights=loss_weights) @parameterized.named_parameters( ('_Seperable', layers.SeparableConv2D(13, 5)), ('_ChannelsFirst', layers.Conv2D(11, 3, data_format='channels_first'))) def testInvalidCNNLayers(self, layer): with self.assertRaises(ValueError): model = tf.keras.Sequential([layers.Input(shape=(28, 28, 3)), layer]) utils.get_layer_collection(model, 'mse') @parameterized.named_parameters( ('_List', ['kron', 'kron_in_diag', 'kron_out_diag', 'kron_both_diag']), ('_Dict', {'l1': 'kron', 'l2': 'kron_in_diag', 'l3': 'kron_out_diag', 'l4': 'kron_both_diag'}), ('_DictOneMissing', {'l2': 'kron_in_diag', 'l3': 'kron_out_diag', 'l4': 'kron_both_diag'})) def testFisherApproxLayerNames(self, fisher_approx): model = tf.keras.Sequential([ layers.Dense(10, input_shape=(20,), name='l1'), layers.Activation('relu'), layers.Dense(13, activation='relu', name='l2'), layers.Dense(23, trainable=False), layers.Dense(17, name='l3'), layers.Activation('relu'), layers.Dense(3, name='l4')]) lc = utils.get_layer_collection(model, 'mse', fisher_approx=fisher_approx) trainable_layers = [model.layers[i] for i in [0, 2, 4, 6]] expected_in_diag_approx = [False, True, False, True] expected_out_diag_approx = [False, False, True, True] for layer, in_diag, out_diag in zip(trainable_layers, expected_in_diag_approx, expected_out_diag_approx): self.assertEqual( in_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_input) self.assertEqual( out_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_output) @parameterized.named_parameters( ('_ClassOnly', {layers.Conv2D: 'diagonal'}, (fisher_blocks.ConvDiagonalFB, fisher_blocks.ConvDiagonalFB)), ('_NameAndClass', {layers.Conv2D: 'diagonal', 'conv2d_1': None}, (fisher_blocks.ConvDiagonalFB, fisher_blocks.ConvKFCBasicFB))) def testFisherApproxLayerClass(self, fisher_approx, block_types): model = _cnn() lc = utils.get_layer_collection(model, 'mse', fisher_approx=fisher_approx) trainable_layers = [model.layers[0], model.layers[2]] for layer, block_type in zip(trainable_layers, block_types): self.assertIsInstance(lc.fisher_blocks[layer.weights], block_type) @parameterized.named_parameters( ('_EmptyList', []), ('_ExtraDict', {'conv2d': 'diagonal', layers.Conv2D: 'kron', 'UWaterloo': 'kron'}), ('_ExtraList', ['kron', 'diagonal', 'diagonal']), ('_WrongName', {'conv2d': 'kron', 'path2': 'kron'})) def testFisherApproxErrors(self, fisher_approx): with self.assertRaisesRegex(ValueError, '.*fisher_approx.*'): utils.get_layer_collection(_cnn(), 'mse', fisher_approx=fisher_approx) @parameterized.named_parameters( ('_List', ['full', 'diagonal'], ['full', 'diagonal']), ('_SerializedDict', {'dense1': 'full', 'dense2': 'diagonal'}, {'dense1': 'full', 'dense2': 'diagonal'}), ('_PartiallySerializedDict', {layers.Dense: 'full', 
utils._CLASS_NAME_PREFIX + 'Conv2D': 'full'}, {utils._CLASS_NAME_PREFIX + 'Dense': 'full', utils._CLASS_NAME_PREFIX + 'Conv2D': 'full'}), ('_Dict', {layers.Dense: 'diagonal', layers.Conv2D: 'full'}, {utils._CLASS_NAME_PREFIX + 'Dense': 'diagonal', utils._CLASS_NAME_PREFIX + 'Conv2D': 'full'})) def testSerializeFisherApprox(self, approx, correctly_serialized_approx): serialized_approx = utils.serialize_fisher_approx(approx) self.assertEqual(serialized_approx, correctly_serialized_approx) def testSeed(self): lc = utils.get_layer_collection(model=_mlp(), loss='mse', seed=4321) self.assertEqual(lc._loss_dict['squared_error_loss'][0]._default_seed, 4321) @parameterized.named_parameters(('_HasShift', True), ('_NoShift', False)) def testNormalizationLayers(self, has_shift): model = tf.keras.Sequential([ layers.Conv2D(13, 5, input_shape=(28, 28, 3)), layers.BatchNormalization(center=has_shift, name='bn'), layers.Conv2D(23, 3), layers.LayerNormalization(center=has_shift), layers.GlobalMaxPool2D(), ]) fisher_approx = {layers.LayerNormalization: 'full', 'bn': 'diagonal'} lc = utils.get_layer_collection(model, 'mse', fisher_approx=fisher_approx) bn_weights = model.layers[1].trainable_weights ln_weights = model.layers[3].trainable_weights if not has_shift: bn_weights, ln_weights = bn_weights[0], ln_weights[0] bn_block = lc.fisher_blocks[bn_weights] ln_block = lc.fisher_blocks[ln_weights] self.assertIsInstance(bn_block, fisher_blocks.ScaleAndShiftDiagonalFB) self.assertIsInstance(ln_block, fisher_blocks.ScaleAndShiftFullFB) self.assertEqual(bn_block._has_shift, has_shift) self.assertEqual(ln_block._has_shift, has_shift) def testErrorWithBatchNormNoScale(self): model = tf.keras.Sequential([ layers.Conv2D(13, 5, input_shape=(28, 28, 3)), layers.BatchNormalization(scale=False, fused=False), layers.GlobalMaxPool2D(), ]) with self.assertRaisesRegex(ValueError, '.*scale=False.*'): utils.get_layer_collection(model, 'binary_crossentropy') def testErrorWithLayerNormNoScale(self): model = tf.keras.Sequential([ layers.Conv2D(13, 5, input_shape=(28, 28, 3)), layers.LayerNormalization(scale=False), layers.GlobalMaxPool2D(), ]) with self.assertRaisesRegex(ValueError, '.*scale=False.*'): utils.get_layer_collection(model, 'binary_crossentropy') def testNumBatchNormUsesWithPhase(self): tf.keras.backend.set_learning_phase(1) model = tf.keras.Sequential([ layers.Conv2D(13, 5, input_shape=(28, 28, 3)), layers.BatchNormalization(fused=True), layers.GlobalMaxPool2D(), ]) lc = utils.get_layer_collection(model, 'binary_crossentropy') for w in model.layers[1].trainable_weights: self.assertEqual(lc._vars_to_uses[w], 1) def testNumBatchNormUsesNoPhase(self): model = tf.keras.Sequential([ layers.Conv2D(13, 5, input_shape=(28, 28, 3)), layers.BatchNormalization(fused=True), layers.GlobalMaxPool2D(), ]) lc = utils.get_layer_collection(model, 'binary_crossentropy') for w in model.layers[1].trainable_weights: self.assertEqual(lc._vars_to_uses[w], 2) def testModelAsCallable(self): model = tf.keras.Sequential([ layers.Conv2D(13, 5), layers.BatchNormalization(name='bn', fused=False), layers.Conv2D(23, 3), layers.LayerNormalization(), layers.GlobalMaxPool2D(), ]) inp = tf.random_normal((10, 28, 28, 3)) inp = tf.keras.Input(tensor=inp) inp2 = tf.random_normal((10, 28, 28, 3)) inp2 = tf.keras.Input(tensor=inp2) fisher_approx = {layers.LayerNormalization: 'full', 'bn': 'diagonal'} _ = model(inp) _ = model(inp2) # with multiple calls, the latest should be registered. 
lc = utils.get_layer_collection(model, 'mse', fisher_approx=fisher_approx) for i in (0, 2): conv_block = lc.fisher_blocks[model.layers[i].trainable_weights] conv_inp = model.layers[i].inbound_nodes[-1].input_tensors conv_out = model.layers[i].inbound_nodes[-1].output_tensors self.assertEqual(conv_inp, conv_block._inputs[0]) self.assertEqual(conv_out, conv_block._outputs[0]) @parameterized.named_parameters( ('_DictApprox', {layers.Dense: 'kron_in_diag', 'l1': 'kron_out_diag', 'l3': 'kron_both_diag'}), ('_ListApprox', ['kron_out_diag', 'kron_in_diag', 'kron_both_diag'])) def testNestedModels(self, fisher_approx): # Note this is not a valid trainable model, it was just created to test # order of the dict and list test the DFS order in utils as well. layer1 = layers.Dense(10, input_shape=(1,), name='l1') layer2 = layers.Dense(10, activation='relu', name='l2') layer3 = layers.Dense(10, activation='relu', name='l3') inner_model0 = tf.keras.Sequential([layer1]) inner_model1 = tf.keras.Sequential() inner_model1.add(inner_model0) inner_model1.add(layers.Activation('relu')) inner_model1.add(layer2) inner_inp = layers.Input(shape=(1,)) x = layer3(inner_inp) x = layers.Reshape(target_shape=(10, 1))(x) x = layers.GlobalMaxPool1D()(x) inner_model2 = tf.keras.Model(inputs=inner_inp, outputs=x) inp = layers.Input(shape=(1,)) branch1 = inner_model1(inp) branch2 = inner_model2(inp) out = layers.Add()([branch1, branch2]) model = tf.keras.Model(inputs=inp, outputs=out) lc = utils.get_layer_collection( model=model, loss='mse', fisher_approx=fisher_approx) expected_in_diag_approx = [False, True, True] expected_out_diag_approx = [True, False, True] trainable_layers = [layer1, layer2, layer3] for layer, in_diag, out_diag in zip(trainable_layers, expected_in_diag_approx, expected_out_diag_approx): self.assertIsInstance(lc.fisher_blocks[layer.weights], fisher_blocks.FullyConnectedKFACBasicFB) self.assertEqual( in_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_input) self.assertEqual( out_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_output) def testMultiOutputNestedModelFails(self): inp = tf.keras.Input(shape=(1,)) out1 = layers.Dense(1)(inp) out2 = layers.Dense(1)(inp) model = tf.keras.Model(inputs=inp, outputs=[out1, out2]) inp2 = tf.keras.Input(shape=(1,)) out = model(inp2) model = tf.keras.Model(inputs=inp2, outputs=out) with self.assertRaisesRegex( ValueError, 'Nested models with multiple outputs are unsupported.'): utils.get_layer_collection(model, loss=['mse', 'mse']) class SerializeLossTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( ('_String', 'binary_crossentropy', 'binary_crossentropy'), ('_KerasLoss', losses.binary_crossentropy, 'binary_crossentropy'), ('_Dict', {'out1': 'binary_crossentropy', 'out2': losses.mean_squared_error}, {'out1': 'binary_crossentropy', 'out2': 'mean_squared_error'}), ('_List', ['mse', tf.keras.losses.categorical_crossentropy], ['mse', 'categorical_crossentropy'])) def testSerializeLoss(self, loss, correctly_serialized_loss): serialized_loss = utils.serialize_loss(loss) self.assertEqual(serialized_loss, correctly_serialized_loss) class GetLossFnTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): super(GetLossFnTest, self).setUp() tf.reset_default_graph() tf.random.set_random_seed(_SEED) @parameterized.parameters( ('categorical_crossentropy', (11, 10), True, True), ('sparse_categorical_crossentropy', (11,), True, False), ('categorical_crossentropy', (11, 10), False, True), ('sparse_categorical_crossentropy', 
(11,), False, False), (losses.CategoricalCrossentropy(), (11, 10), True, True), (losses.categorical_crossentropy, (11, 10), False, True)) def testCrossEntropy(self, loss, label_shape, is_logits, use_regularization): conv_kwargs = {'kernel_regularizer': 'l2'} if use_regularization else {} model_layers = [ layers.Conv2D(7, 5, input_shape=(32, 32, 3), **conv_kwargs), layers.Activation('relu'), layers.Conv2D(10, (3, 3), activation='relu', **conv_kwargs), layers.GlobalMaxPool2D() ] if is_logits: model_layers.append(layers.Activation('softmax')) model = tf.keras.Sequential(model_layers) model.compile('sgd', loss) loss_fn = utils.get_loss_fn(model=model, loss=loss) x = tf.constant(np.random.random((11, 32, 32, 3)).astype(np.float32)) y = tf.constant(np.random.random(label_shape).astype(np.float32)) model_loss = model.evaluate(x, y, steps=1) fn_loss = tf.keras.backend.get_value(loss_fn((x, y))) fn_loss_w_pred = tf.keras.backend.get_value( loss_fn((x, y), prediction=model(x))) self.assertAlmostEqual(model_loss, fn_loss, places=5) self.assertAlmostEqual(fn_loss, fn_loss_w_pred, places=5) model.train_on_batch(np.random.random((11, 32, 32, 3)), np.random.random(label_shape)) x = tf.constant(np.random.random((11, 32, 32, 3)).astype(np.float32)) y = tf.constant(np.random.random(label_shape).astype(np.float32)) model_loss = model.test_on_batch(x, y) fn_loss = tf.keras.backend.get_value(loss_fn((x, y))) fn_loss_w_pred = tf.keras.backend.get_value( loss_fn((x, y), prediction=model(x))) self.assertAlmostEqual(model_loss, fn_loss, places=5) self.assertAlmostEqual(fn_loss, fn_loss_w_pred, places=5) @parameterized.parameters('categorical_crossentropy', losses.CategoricalCrossentropy(), losses.CategoricalCrossentropy(from_logits=False), losses.categorical_crossentropy) def testCrossEntropyCustomLoop(self, loss): model_layers = [ layers.Conv2D(7, 5, input_shape=(32, 32, 3)), layers.Activation('relu'), layers.Conv2D(10, (3, 3), kernel_regularizer='l2'), layers.GlobalMaxPool2D() ] model = tf.keras.Sequential(model_layers) model.compile('sgd', loss) loss_fn = utils.get_loss_fn(model=model, loss=loss) x = np.random.random((11, 32, 32, 3)).astype(np.float32) y = np.random.random((11, 10)).astype(np.float32) tf_x = tf.constant(x) tf_y = tf.constant(y) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) model_loss = sess.run( model.total_loss, feed_dict={'conv2d_input:0': x, 'global_max_pooling2d_target:0': y}) fn_loss = sess.run(loss_fn((tf_x, tf_y))) fn_loss_w_pred = sess.run(loss_fn((tf_x, tf_y), prediction=model(tf_x))) self.assertAlmostEqual(model_loss, fn_loss, fn_loss_w_pred) @parameterized.parameters( 'mse', 'MSE', 'mean_squared_error', losses.mean_squared_error) def testMSE(self, loss): model = _mlp() model.compile('sgd', loss) loss_fn = utils.get_loss_fn(model=model, loss=loss) x = tf.constant(np.random.random((23, 1)).astype(np.float32)) y = tf.constant(np.random.random((23, 1)).astype(np.float32)) model_loss = model.test_on_batch(x, y) fn_loss = tf.keras.backend.get_value(loss_fn((x, y))) fn_loss_w_pred = tf.keras.backend.get_value( loss_fn((x, y), prediction=model(x))) self.assertAlmostEqual(model_loss, fn_loss, fn_loss_w_pred) @parameterized.parameters( ({'out1': 'mse', 'out2': losses.categorical_crossentropy}, [0.3, 0.7]), (['categorical_crossentropy', losses.MeanSquaredError()], {'out2': 0.1})) def testMultiLoss(self, multi_loss, loss_weights): inps, outs = _two_loss_model() model = tf.keras.Model(inputs=inps, outputs=outs) model.compile('sgd', multi_loss, loss_weights=loss_weights) 
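    # Added note: the standalone loss_fn built below should reproduce the
    # weighted total loss that model.test_on_batch reports for the same data.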
    loss_fn = utils.get_loss_fn(
        model=model, loss=multi_loss, loss_weights=loss_weights)
    x = tf.constant(np.random.random((11, 28, 28, 1)).astype(np.float32))
    y_1 = tf.constant(np.random.random((11, 1)).astype(np.float32))
    y_2 = tf.constant(np.random.random((11, 9)).astype(np.float32))
    # test_on_batch returns the total loss and the two individual losses.
    # We just want the total, so we use model_loss[0].
    model_loss = model.test_on_batch(x, [y_1, y_2])[0]
    fn_loss = tf.keras.backend.get_value(loss_fn((x, [y_1, y_2])))
    fn_loss_w_pred = tf.keras.backend.get_value(
        loss_fn((x, [y_1, y_2]), prediction=model(x)))
    self.assertAlmostEqual(model_loss, fn_loss, places=5)
    self.assertAlmostEqual(fn_loss, fn_loss_w_pred, places=5)


if __name__ == '__main__':
  tf.disable_v2_behavior()
  tf.test.main()


================================================
FILE: kfac/python/kernel_tests/layer_collection_test.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
self.assertTrue(x in lp_dict) self.assertFalse(y in lp_dict) class LayerCollectionTest(tf.test.TestCase): def testLayerCollectionInit(self): lc = layer_collection.LayerCollection() self.assertEqual(0, len(lc.get_blocks())) self.assertEqual(0, len(lc.get_factors())) self.assertFalse(lc.losses) def testRegisterBlocks(self): with tf.Graph().as_default(): tf.set_random_seed(200) lc = layer_collection.LayerCollection() lc.register_fully_connected( tf.constant(1), tf.constant(2), tf.constant(3)) lc.register_fully_connected( tf.constant(1), tf.constant(2), tf.constant(3), approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_conv2d( params=tf.ones((2, 3, 4, 5)), strides=[1, 1, 1, 1], padding='SAME', inputs=tf.ones((1, 2, 3, 4)), outputs=tf.ones((1, 1, 1, 5))) lc.register_conv2d( params=tf.ones((2, 3, 4, 5)), strides=[1, 1, 1, 1], padding='SAME', inputs=tf.ones((1, 2, 3, 4)), outputs=tf.ones((1, 1, 1, 5)), approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_separable_conv2d( depthwise_params=tf.ones((3, 3, 1, 2)), pointwise_params=tf.ones((1, 1, 2, 4)), inputs=tf.ones((32, 5, 5, 1)), depthwise_outputs=tf.ones((32, 5, 5, 2)), pointwise_outputs=tf.ones((32, 5, 5, 4)), strides=[1, 1, 1, 1], padding='SAME') lc.register_convolution( params=tf.ones((3, 3, 1, 8)), inputs=tf.ones((32, 5, 5, 1)), outputs=tf.ones((32, 5, 5, 8)), padding='SAME') lc.register_generic( tf.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) lc.register_generic( tf.constant(6), 16, approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_fully_connected_multi( tf.constant(1), (tf.constant(2), tf.constant(3)), (tf.constant(4), tf.constant(5))) lc.register_conv2d_multi( params=tf.ones((2, 3, 4, 5)), strides=[1, 1, 1, 1], padding='SAME', inputs=(tf.ones((1, 2, 3, 4)), tf.ones((5, 6, 7, 8))), outputs=(tf.ones((1, 1, 1, 5)), tf.ones((2, 2, 2, 10)))) lc.register_fully_connected_multi( tf.constant((1,)), (tf.constant(2), tf.constant(3)), (tf.constant(4), tf.constant(5)), approx=layer_collection.APPROX_KRONECKER_INDEP_IN_DIAG_NAME) lc.register_fully_connected_multi( tf.constant((1,)), (tf.constant(2), tf.constant(3)), (tf.constant(4), tf.constant(5)), dense_inputs=False, approx=layer_collection.APPROX_KRONECKER_INDEP_IN_DIAG_NAME) self.assertEqual(13, len(lc.get_blocks())) def testRegisterBlocksMultipleRegistrations(self): with tf.Graph().as_default(): tf.set_random_seed(200) lc = layer_collection.LayerCollection() key = tf.constant(1) lc.register_fully_connected(key, tf.constant(2), tf.constant(3)) with self.assertRaises(ValueError) as cm: lc.register_generic(key, 16) self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterSingleParamNotRegistered(self): x = tf.get_variable('x', initializer=tf.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {tf.get_variable('y', initializer=tf.constant(1,)): '1'} lc._register_block(x, 'foo') def testShouldRegisterSingleParamRegistered(self): x = tf.get_variable('x', initializer=tf.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {x: '1'} with self.assertRaises(ValueError) as cm: lc._register_block(x, 'foo') self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterSingleParamRegisteredInTuple(self): x = tf.get_variable('x', initializer=tf.constant(1,)) y = tf.get_variable('y', initializer=tf.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, y): '1'} with self.assertRaises(ValueError) as cm: lc._register_block(x, 'foo') self.assertIn('was already registered', 
str(cm.exception)) def testRegisterTupleParamNotRegistered(self): x = tf.get_variable('x', initializer=tf.constant(1,)) y = tf.get_variable('y', initializer=tf.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {tf.get_variable('z', initializer=tf.constant(1,)): '1'} lc._register_block((x, y), 'foo') self.assertEqual(set(['1', 'foo']), set(lc.get_blocks())) def testRegisterTupleParamRegistered(self): x = tf.get_variable('x', initializer=tf.constant(1,)) y = tf.get_variable('y', initializer=tf.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, y): '1'} with self.assertRaises(ValueError) as cm: lc._register_block((x, y), 'foo') self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterTupleParamRegisteredInSuperset(self): x = tf.get_variable('x', initializer=tf.constant(1,)) y = tf.get_variable('y', initializer=tf.constant(1,)) z = tf.get_variable('z', initializer=tf.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, y, z): '1'} with self.assertRaises(ValueError) as cm: lc._register_block((x, y), 'foo') self.assertIn('was already registered', str(cm.exception)) def testRegisterTupleParamSomeRegistered(self): x = tf.get_variable('x', initializer=tf.constant(1,)) y = tf.get_variable('y', initializer=tf.constant(1,)) z = tf.get_variable('z', initializer=tf.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')} with self.assertRaises(ValueError) as cm: lc._register_block((x, y), MockFisherBlock('foo')) self.assertIn('was already registered', str(cm.exception)) def testRegisterTupleVarSomeRegisteredInOtherTuples(self): x = tf.get_variable('x', initializer=tf.constant(1,)) y = tf.get_variable('y', initializer=tf.constant(1,)) z = tf.get_variable('z', initializer=tf.constant(1,)) w = tf.get_variable('w', initializer=tf.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, z): '1', (z, w): '2'} with self.assertRaises(ValueError) as cm: lc._register_block((x, y), 'foo') self.assertIn('was already registered', str(cm.exception)) def testRegisterCategoricalPredictiveDistribution(self): with tf.Graph().as_default(), self.test_session() as sess: tf.set_random_seed(200) logits = tf.eye(2) lc = layer_collection.LayerCollection() lc.register_categorical_predictive_distribution(logits, seed=200) single_loss = sess.run(lc.total_sampled_loss()) lc2 = layer_collection.LayerCollection() lc2.register_categorical_predictive_distribution(logits, seed=200) lc2.register_categorical_predictive_distribution(logits, seed=200) double_loss = sess.run(lc2.total_sampled_loss()) self.assertAlmostEqual(2 * single_loss, double_loss) def testLossFunctionByName(self): """Ensure loss functions can be identified by name.""" with tf.Graph().as_default(): logits = tf.eye(2) lc = layer_collection.LayerCollection() # Create a new loss function by name. lc.register_categorical_predictive_distribution(logits, name='loss1') self.assertEqual(1, len(lc.towers_by_loss)) # Add logits to same loss function. lc.register_categorical_predictive_distribution( logits, name='loss1', reuse=True) self.assertEqual(1, len(lc.towers_by_loss)) # Add another new loss function. 
lc.register_categorical_predictive_distribution(logits, name='loss2') self.assertEqual(2, len(lc.towers_by_loss)) def testLossFunctionWithoutName(self): """Ensure loss functions get unique names if 'name' not specified.""" with tf.Graph().as_default(): logits = tf.eye(2) lc = layer_collection.LayerCollection() # Create a new loss function with default names. lc.register_categorical_predictive_distribution(logits) lc.register_categorical_predictive_distribution(logits) self.assertEqual(2, len(lc.losses)) def testCategoricalPredictiveDistributionMultipleMinibatches(self): """Ensure multiple minibatches are registered.""" with tf.Graph().as_default(): batch_size = 3 output_size = 2 logits = tf.zeros([batch_size, output_size]) targets = tf.ones([batch_size], dtype=tf.int32) lc = layer_collection.LayerCollection() # Create a new loss function. lc.register_categorical_predictive_distribution( logits, targets=targets, name='loss1') # Can add when reuse=True lc.register_categorical_predictive_distribution( logits, targets=targets, name='loss1', reuse=True) # Can add when reuse=VARIABLE_SCOPE and reuse=True there. with tf.variable_scope(tf.get_variable_scope(), reuse=True): lc.register_categorical_predictive_distribution( logits, targets=targets, name='loss1', reuse=layer_collection.VARIABLE_SCOPE) # Can't add when reuse=False with self.assertRaises(KeyError): lc.register_categorical_predictive_distribution( logits, targets=targets, name='loss1', reuse=False) # Can't add when reuse=VARIABLE_SCOPE and reuse=False there. with self.assertRaises(KeyError): lc.register_categorical_predictive_distribution( logits, targets=targets, name='loss1', reuse=layer_collection.VARIABLE_SCOPE) self.assertEqual(len(lc.towers_by_loss), 1) # Three successful registrations. self.assertEqual(len(lc.towers_by_loss[0]), 3) def testRegisterCategoricalPredictiveDistributionBatchSize1(self): with tf.Graph().as_default(): tf.set_random_seed(200) logits = tf.random_normal((1, 2)) lc = layer_collection.LayerCollection() lc.register_categorical_predictive_distribution(logits, seed=200) def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self): with tf.Graph().as_default(), self.test_session() as sess: tf.set_random_seed(200) logits = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32) lc = layer_collection.LayerCollection() targets = tf.constant([0, 1], dtype=tf.int32) lc.register_categorical_predictive_distribution(logits, targets=targets) single_loss = sess.run(lc.total_loss()) self.assertAlmostEqual(1.6265233, single_loss) def testRegisterNormalPredictiveDistribution(self): with tf.Graph().as_default(), self.test_session() as sess: tf.set_random_seed(200) predictions = tf.constant([[1., 2.], [3., 4]], dtype=tf.float32) lc = layer_collection.LayerCollection() lc.register_normal_predictive_distribution(predictions, 1., seed=200) single_loss = sess.run(lc.total_sampled_loss()) lc2 = layer_collection.LayerCollection() lc2.register_normal_predictive_distribution(predictions, 1., seed=200) lc2.register_normal_predictive_distribution(predictions, 1., seed=200) double_loss = sess.run(lc2.total_sampled_loss()) self.assertAlmostEqual(2 * single_loss, double_loss) def testRegisterNormalPredictiveDistributionSpecifiedTargets(self): with tf.Graph().as_default(), self.test_session() as sess: tf.set_random_seed(200) predictions = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32) lc = layer_collection.LayerCollection() targets = tf.constant([[3., 1.], [4., 2.]], dtype=tf.float32) lc.register_normal_predictive_distribution( 
predictions, 2.**2, targets=targets) single_loss = sess.run(lc.total_loss()) self.assertAlmostEqual(7.6983433, single_loss) def ensureLayerReuseWorks(self, register_fn): """Ensure the 'reuse' keyword argument function as intended. Args: register_fn: function for registering a layer. Arguments are layer_collection, reuse, and approx. """ # Fails on second if reuse=False. lc = layer_collection.LayerCollection() register_fn(lc) with self.assertRaises(ValueError): register_fn(lc, reuse=False) # Succeeds on second if reuse=True. lc = layer_collection.LayerCollection() register_fn(lc) register_fn(lc, reuse=True) # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse. lc = layer_collection.LayerCollection() register_fn(lc) with self.assertRaises(ValueError): register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE) # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse. lc = layer_collection.LayerCollection() register_fn(lc) with tf.variable_scope(tf.get_variable_scope(), reuse=True): register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE) # Fails if block type changes. lc = layer_collection.LayerCollection() register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME) with self.assertRaises(ValueError): register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True) # Fails if reuse requested but no FisherBlock exists. lc = layer_collection.LayerCollection() with self.assertRaises(ValueError): register_fn(lc, reuse=True) def testRegisterFullyConnectedReuse(self): """Ensure the 'reuse' works with register_fully_connected.""" with tf.Graph().as_default(): inputs = tf.ones([2, 10]) outputs = tf.zeros([2, 5]) params = ( tf.get_variable('w', [10, 5]), # tf.get_variable('b', [5])) def register_fn(lc, **kwargs): lc.register_fully_connected( params=params, inputs=inputs, outputs=outputs, **kwargs) self.ensureLayerReuseWorks(register_fn) def testRegisterConv2dReuse(self): """Ensure the 'reuse' works with register_conv2d.""" with tf.Graph().as_default(): inputs = tf.ones([2, 5, 5, 10]) outputs = tf.zeros([2, 5, 5, 3]) params = ( tf.get_variable('w', [1, 1, 10, 3]), # tf.get_variable('b', [3])) def register_fn(lc, **kwargs): lc.register_conv2d( params=params, strides=[1, 1, 1, 1], padding='SAME', inputs=inputs, outputs=outputs, **kwargs) self.ensureLayerReuseWorks(register_fn) def testReuseWithInvalidRegistration(self): """Invalid registrations shouldn't overwrite existing blocks.""" with tf.Graph().as_default(): inputs = tf.ones([2, 5, 5, 10]) outputs = tf.zeros([2, 5, 5, 3]) w = tf.get_variable('w', [1, 1, 10, 3]) b = tf.get_variable('b', [3]) lc = layer_collection.LayerCollection() lc.register_fully_connected(w, inputs, outputs) self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) with self.assertRaises(ValueError): lc.register_fully_connected((w, b), inputs, outputs, reuse=True) self.assertNotIn((w, b), lc.fisher_blocks) self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) lc.register_fully_connected(w, inputs, outputs, reuse=True) self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2) def testMakeOrGetFactor(self): with tf.Graph().as_default(): tf.set_random_seed(200) lc = layer_collection.LayerCollection() key = tf.constant(1) lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((key,), 16)) lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((key,), 16)) lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((tf.constant(2),), 16)) self.assertEqual(2, len(lc.get_factors())) variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 
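      # Added note: every variable created by the factors should live under the
      # default 'LayerCollection' name scope, which the assertion below checks.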
self.assertTrue( all([var.name.startswith('LayerCollection') for var in variables])) def testMakeOrGetFactorCustomScope(self): with tf.Graph().as_default(): tf.set_random_seed(200) scope = 'Foo' lc = layer_collection.LayerCollection(name=scope) key = tf.constant(1) lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((key,), 16)) lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((key,), 16)) lc.make_or_get_factor(fisher_factors.NaiveFullFactor, ((tf.constant(2),), 16)) self.assertEqual(2, len(lc.get_factors())) variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) self.assertTrue(all([var.name.startswith(scope) for var in variables])) def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): x = tf.get_variable('x', shape=()) y = tf.get_variable('y', shape=()) z = tf.get_variable('z', shape=()) lc = layer_collection.LayerCollection() lc.define_linked_parameters((x, y)) with self.assertRaises(ValueError): lc.define_linked_parameters((x, z)) def testIdentifySubsetPreviouslyRegisteredTensor(self): x = tf.get_variable('x', shape=()) y = tf.get_variable('y', shape=()) lc = layer_collection.LayerCollection() lc.define_linked_parameters((x, y)) with self.assertRaises(ValueError): lc.define_linked_parameters(x) def testSpecifyApproximation(self): w_0 = tf.get_variable('w_0', [10, 10]) w_1 = tf.get_variable('w_1', [10, 10]) b_0 = tf.get_variable('b_0', [10]) b_1 = tf.get_variable('b_1', [10]) x_0 = tf.placeholder(tf.float32, shape=(32, 10)) x_1 = tf.placeholder(tf.float32, shape=(32, 10)) pre_bias_0 = tf.matmul(x_0, w_0) pre_bias_1 = tf.matmul(x_1, w_1) # Build the fully connected layers in the graph. pre_bias_0 + b_0 # pylint: disable=pointless-statement pre_bias_1 + b_1 # pylint: disable=pointless-statement lc = layer_collection.LayerCollection() lc.define_linked_parameters( w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME) lc.define_linked_parameters( w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME) lc.define_linked_parameters( b_0, approximation=layer_collection.APPROX_FULL_NAME) lc.define_linked_parameters( b_1, approximation=layer_collection.APPROX_FULL_NAME) lc.register_fully_connected(w_0, x_0, pre_bias_0) lc.register_fully_connected( w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME) self.assertIsInstance(lc.fisher_blocks[w_0], fisher_blocks.FullyConnectedDiagonalFB) self.assertIsInstance(lc.fisher_blocks[w_1], fisher_blocks.FullyConnectedKFACBasicFB) lc.register_generic(b_0, batch_size=1) lc.register_generic( b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME) self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.NaiveFullFB) self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) def testDefaultLayerCollection(self): with tf.Graph().as_default(): # Can't get default if there isn't one set. with self.assertRaises(ValueError): layer_collection.get_default_layer_collection() # Can't set default twice. lc = layer_collection.LayerCollection() layer_collection.set_default_layer_collection(lc) with self.assertRaises(ValueError): layer_collection.set_default_layer_collection(lc) # Same as one set. self.assertTrue(lc is layer_collection.get_default_layer_collection()) # Can set to None. layer_collection.set_default_layer_collection(None) with self.assertRaises(ValueError): layer_collection.get_default_layer_collection() # as_default() is the same as setting/clearing. 
with lc.as_default(): self.assertTrue(lc is layer_collection.get_default_layer_collection()) with self.assertRaises(ValueError): layer_collection.get_default_layer_collection() if __name__ == '__main__': tf.disable_v2_behavior() tf.test.main() ================================================ FILE: kfac/python/kernel_tests/loss_functions_test.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for kfac.loss_functions.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import numpy as np import tensorflow.compat.v1 as tf from kfac.python.ops import loss_functions class InsertSliceInZerosTest(tf.test.TestCase): def testBadShape(self): bad_shaped_ones = tf.ones(shape=[1, 3]) # n.b. shape[1] != 1 with self.assertRaises(ValueError): loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17) def test3d(self): input_tensor = tf.constant([[[1, 2]], [[3, 4]]]) expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]] op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0) with self.test_session() as sess: actual_output_array = sess.run(op) self.assertAllEqual(expected_output_array, actual_output_array) class CategoricalLogitsNegativeLogProbLossTest(tf.test.TestCase): def testSample(self): """Ensure samples can be drawn.""" with tf.Graph().as_default(), self.test_session() as sess: logits = np.asarray([ [0., 0., 0.], # [1., -1., 0.] ]).astype(np.float32) loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( tf.constant(logits)) sample = loss.sample(42) sample = sess.run(sample) self.assertEqual(sample.shape, (2,)) def testEvaluateOnTargets(self): """Ensure log probability can be evaluated correctly.""" with tf.Graph().as_default(), self.test_session() as sess: logits = np.asarray([ [0., 0., 0.], # [1., -1., 0.] ]).astype(np.float32) targets = np.asarray([2, 1]).astype(np.int32) loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( tf.constant(logits), targets=tf.constant(targets)) neg_log_prob = loss.evaluate() neg_log_prob = sess.run(neg_log_prob) # Calculate explicit log probability of targets. probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) log_probs = np.log([ probs[0, targets[0]], # probs[1, targets[1]] ]) expected_log_prob = np.sum(log_probs) self.assertAllClose(neg_log_prob, -expected_log_prob) def testEvaluateOnSample(self): """Ensure log probability of a sample can be drawn.""" with tf.Graph().as_default(), self.test_session() as sess: logits = np.asarray([ [0., 0., 0.], # [1., -1., 0.] ]).astype(np.float32) loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( tf.constant(logits)) neg_log_prob = loss.evaluate_on_sample(42) # Simply ensure this doesn't crash. As the output is random, it's # difficult to say if the output is correct or not... 
neg_log_prob = sess.run(neg_log_prob) def testMultiplyFisherSingleVector(self): with tf.Graph().as_default(), self.test_session() as sess: logits = np.array([1., 2., 3.]) loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) # the LossFunction.multiply_fisher docstring only says it supports the # case where the vector is the same shape as the input natural parameters # (i.e. the logits here), but here we also test leading dimensions vector = np.array([1., 2., 3.]) vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)] probs = np.exp(logits - np.logaddexp.reduce(logits)) fisher = np.diag(probs) - np.outer(probs, probs) for vector in vectors: result = loss.multiply_fisher(vector) expected_result = np.dot(vector, fisher) self.assertAllClose(expected_result, sess.run(result)) def testMultiplyFisherBatch(self): with tf.Graph().as_default(), self.test_session() as sess: logits = np.array([[1., 2., 3.], [4., 6., 8.]]) loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) vector = np.array([[1., 2., 3.], [5., 3., 1.]]) na = np.newaxis probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1, keepdims=True)) fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :] result = loss.multiply_fisher(vector) expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :] self.assertEqual(sess.run(result).shape, logits.shape) self.assertAllClose(expected_result, sess.run(result)) class OnehotCategoricalLogitsNegativeLogProbLossTest(tf.test.TestCase): def testSample(self): """Ensure samples can be drawn.""" with tf.Graph().as_default(), self.test_session() as sess: logits = np.asarray([ [0., 0., 0.], # [1., -1., 0.] ]).astype(np.float32) loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( tf.constant(logits)) sample = loss.sample(42) sample = sess.run(sample) self.assertEqual(sample.shape, (2, 3)) def testEvaluateOnTargets(self): """Ensure log probability can be evaluated correctly.""" with tf.Graph().as_default(), self.test_session() as sess: logits = np.asarray([ [0., 0., 0.], # [1., -1., 0.] ]).astype(np.float32) targets = np.asarray([2, 1]).astype(np.int32) loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( tf.constant(logits), targets=tf.one_hot(targets, 3)) neg_log_prob = loss.evaluate() neg_log_prob = sess.run(neg_log_prob) # Calculate explicit log probability of targets. probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) log_probs = np.log([ probs[0, targets[0]], # probs[1, targets[1]] ]) expected_log_prob = np.sum(log_probs) self.assertAllClose(neg_log_prob, -expected_log_prob) def testEvaluateOnSample(self): """Ensure log probability of a sample can be drawn.""" with tf.Graph().as_default(), self.test_session() as sess: logits = np.asarray([ [0., 0., 0.], # [1., -1., 0.] ]).astype(np.float32) loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( tf.constant(logits)) neg_log_prob = loss.evaluate_on_sample(42) # Simply ensure this doesn't crash. As the output is random, it's # difficult to say if the output is correct or not... neg_log_prob = sess.run(neg_log_prob) if __name__ == "__main__": tf.disable_v2_behavior() tf.test.main() ================================================ FILE: kfac/python/kernel_tests/op_queue_test.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for kfac.op_queue.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import tensorflow.compat.v1 as tf from kfac.python.ops import op_queue class OpQueueTest(tf.test.TestCase): def testNextOp(self): """Ensures all ops get selected eventually.""" with tf.Graph().as_default(): ops = [ tf.add(1, 2), tf.subtract(1, 2), tf.reduce_mean([1, 2]), ] queue = op_queue.OpQueue(ops, seed=0) with self.test_session() as sess: # Ensure every inv update op gets selected. selected_ops = set([queue.next_op(sess) for _ in ops]) self.assertEqual(set(ops), set(selected_ops)) # Ensure additional calls don't create any new ops. selected_ops.add(queue.next_op(sess)) self.assertEqual(set(ops), set(selected_ops)) if __name__ == "__main__": tf.disable_v2_behavior() tf.test.main() ================================================ FILE: kfac/python/kernel_tests/optimizer_test.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for kfac.optimizer.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import numpy as np import tensorflow.compat.v1 as tf from kfac.python.ops import fisher_factors as ff from kfac.python.ops import layer_collection as lc from kfac.python.ops import optimizer def dummy_layer_collection(): lcoll = lc.LayerCollection() dummy = tf.constant([1., 2.]) lcoll.register_categorical_predictive_distribution(logits=dummy) return lcoll class OptimizerTest(tf.test.TestCase): def testOptimizerInitInvalidMomentumRegistration(self): with self.assertRaises(ValueError): optimizer.KfacOptimizer( 0.1, 0.2, lc.LayerCollection(), 0.3, momentum_type='foo') def testOptimizerInit(self): with tf.Graph().as_default(): layer_collection = lc.LayerCollection() inputs = tf.ones((2, 1)) * 2 weights_val = np.ones((1, 1), dtype=np.float32) * 3. 
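      # Added note: register a single dense layer (weights and bias) plus a
      # categorical loss so the optimizer is built from a complete
      # LayerCollection.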
weights = tf.get_variable('w', initializer=tf.constant(weights_val)) bias = tf.get_variable( 'b', initializer=tf.zeros_initializer(), shape=(1, 1)) output = tf.matmul(inputs, weights) + bias layer_collection.register_fully_connected((weights, bias), inputs, output) logits = tf.tanh(output) targets = tf.constant([[0.], [1.]]) output = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits( logits=logits, labels=targets)) layer_collection.register_categorical_predictive_distribution(logits) optimizer.KfacOptimizer( 0.1, 0.2, layer_collection, 0.3, momentum=0.5, momentum_type='regular') def testSquaredFisherNorm(self): with tf.Graph().as_default(), self.test_session() as sess: grads_and_vars = [(tf.constant([[1., 2.], [3., 4.]]), None), (tf.constant([[2., 3.], [4., 5.]]), None)] pgrads_and_vars = [(tf.constant([[3., 4.], [5., 6.]]), None), (tf.constant([[7., 8.], [9., 10.]]), None)] opt = optimizer.KfacOptimizer(0.1, 0.2, dummy_layer_collection(), 0.3) sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) self.assertAlmostEqual(174., sess.run(sq_norm), places=5) def testUpdateClipCoeff(self): with tf.Graph().as_default(), self.test_session() as sess: grads_and_vars = [(tf.constant([[1., 2.], [3., 4.]]), None), (tf.constant([[2., 3.], [4., 5.]]), None)] pgrads_and_vars = [(tf.constant([[3., 4.], [5., 6.]]), None), (tf.constant([[7., 8.], [9., 10.]]), None)] lrate = 0.1 # Note: without rescaling, the squared Fisher norm of the update # is 1.74 # If the update already satisfies the norm constraint, there should # be no rescaling. opt = optimizer.KfacOptimizer( lrate, 0.2, dummy_layer_collection(), 0.3, norm_constraint=10., name='KFAC_1') coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) self.assertAlmostEqual(1., sess.run(coeff), places=5) # If the update violates the constraint, it should be rescaled to # be on the constraint boundary. opt = optimizer.KfacOptimizer( lrate, 0.2, dummy_layer_collection(), 0.3, norm_constraint=0.5, name='KFAC_2') coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5) def testUpdateVelocities(self): with tf.Graph().as_default(), self.test_session() as sess: layers = lc.LayerCollection() layers.register_categorical_predictive_distribution(tf.constant([1.0])) opt = optimizer.KfacOptimizer( 0.1, 0.2, layers, 0.3, momentum=0.5, momentum_type='regular') x = tf.get_variable('x', initializer=tf.ones((2, 2))) y = tf.get_variable('y', initializer=tf.ones((2, 2)) * 2) vec1 = tf.ones((2, 2)) * 3 vec2 = tf.ones((2, 2)) * 4 model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5) opt_vars = [ v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if v not in model_vars ] sess.run(tf.global_variables_initializer()) old_opt_vars = sess.run(opt_vars) # Optimizer vars start out at 0. for opt_var in old_opt_vars: self.assertAllEqual(sess.run(tf.zeros_like(opt_var)), opt_var) sess.run(update_op) new_opt_vars = sess.run(opt_vars) # After one update, the velocities are equal to the vectors. 
for vec, opt_var in zip([vec1, vec2], new_opt_vars): self.assertAllEqual(sess.run(vec), opt_var) sess.run(update_op) final_opt_vars = sess.run(opt_vars) for first, second in zip(new_opt_vars, final_opt_vars): self.assertFalse(np.equal(first, second).all()) def testApplyGradients(self): with tf.Graph().as_default(), self.test_session() as sess: layer_collection = lc.LayerCollection() inputs = tf.ones((2, 1)) * 2 weights_val = np.ones((1, 1), dtype=np.float32) * 3. weights = tf.get_variable('w', initializer=tf.constant(weights_val)) bias = tf.get_variable( 'b', initializer=tf.zeros_initializer(), shape=(1, 1)) output = tf.matmul(inputs, weights) + bias layer_collection.register_fully_connected((weights, bias), inputs, output) preds = output targets = tf.constant([[0.34], [1.56]]) output = tf.reduce_mean(tf.square(targets - preds)) layer_collection.register_squared_error_loss(preds) opt = optimizer.KfacOptimizer( 0.1, 0.2, layer_collection, cov_ema_decay=0.3, momentum=0.5, momentum_type='regular') (cov_update_thunks, inv_update_thunks) = opt.make_vars_and_create_op_thunks() cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) grads_and_vars = opt.compute_gradients(output, [weights, bias]) all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars] op = opt.apply_gradients(grads_and_vars) sess.run(tf.global_variables_initializer()) old_vars = sess.run(all_vars) sess.run(cov_update_ops) sess.run(inv_update_ops) sess.run(op) new_vars = sess.run(all_vars) for old_var, new_var in zip(old_vars, new_vars): self.assertNotEqual(old_var, new_var) if __name__ == '__main__': tf.disable_v2_behavior() tf.test.main() ================================================ FILE: kfac/python/kernel_tests/periodic_inv_cov_update_kfac_opt_test.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
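
# A small self-contained sketch (not taken from the KFAC sources) of the
# arithmetic asserted by testSquaredFisherNorm and testUpdateClipCoeff in
# optimizer_test.py above.  It assumes, as those tests imply, that the squared
# Fisher norm is sum_i <grad_i, pgrad_i> and that the clip coefficient is
# min(1, sqrt(norm_constraint / (lr**2 * sq_norm))); the helper names below
# are hypothetical and exist only to make the numbers 174, 1.74 and 0.5
# explicit.

import numpy as np


def squared_fisher_norm(grads, pgrads):
  # Sum of elementwise inner products <grad_i, pgrad_i>.
  return sum(float(np.sum(g * pg)) for g, pg in zip(grads, pgrads))


def update_clip_coeff(grads, pgrads, lr, norm_constraint):
  sq_norm_update = lr**2 * squared_fisher_norm(grads, pgrads)
  return min(1.0, np.sqrt(norm_constraint / sq_norm_update))


grads = [np.array([[1., 2.], [3., 4.]]), np.array([[2., 3.], [4., 5.]])]
pgrads = [np.array([[3., 4.], [5., 6.]]), np.array([[7., 8.], [9., 10.]])]

assert squared_fisher_norm(grads, pgrads) == 174.0
# With lr = 0.1 the squared norm of the update is 0.1**2 * 174 = 1.74, so a
# norm constraint of 10.0 leaves the update unscaled ...
assert update_clip_coeff(grads, pgrads, lr=0.1, norm_constraint=10.0) == 1.0
# ... while a constraint of 0.5 rescales it onto the constraint boundary.
coeff = update_clip_coeff(grads, pgrads, lr=0.1, norm_constraint=0.5)
assert abs(0.1**2 * coeff**2 * 174.0 - 0.5) < 1e-12
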
# ==============================================================================
"""Tests for the PeriodicInvCovUpdateKfacOpt class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import sonnet as snt
import tensorflow.compat.v1 as tf

from kfac.python.ops import layer_collection
from kfac.python.ops.kfac_utils import periodic_inv_cov_update_kfac_opt
from kfac.python.ops.tensormatch import graph_search

_BATCH_SIZE = 128


def _construct_layer_collection(layers, all_logits, var_list):
  for idx, logits in enumerate(all_logits):
    tf.logging.info("Registering logits: %s", logits)
    with tf.variable_scope(tf.get_variable_scope(), reuse=(idx > 0)):
      layers.register_categorical_predictive_distribution(
          logits, name="register_logits")
  batch_size = all_logits[0].shape.as_list()[0]
  vars_to_register = var_list if var_list else tf.trainable_variables()
  graph_search.register_layers(layers, vars_to_register, batch_size)


class PeriodicInvCovUpdateKfacOptTest(tf.test.TestCase):

  def test_train(self):
    image = tf.random_uniform(shape=(_BATCH_SIZE, 784), maxval=1.)
    labels = tf.random_uniform(shape=(_BATCH_SIZE,), maxval=10, dtype=tf.int32)
    labels_one_hot = tf.one_hot(labels, 10)
    model = snt.Sequential([snt.BatchFlatten(), snt.nets.MLP([128, 128, 10])])
    logits = model(image)
    all_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=labels_one_hot)
    loss = tf.reduce_mean(all_losses)
    layers = layer_collection.LayerCollection()
    optimizer = periodic_inv_cov_update_kfac_opt.PeriodicInvCovUpdateKfacOpt(
        invert_every=10,
        cov_update_every=1,
        learning_rate=0.03,
        cov_ema_decay=0.95,
        damping=100.,
        layer_collection=layers,
        momentum=0.9,
        num_burnin_steps=0,
        placement_strategy="round_robin")
    _construct_layer_collection(layers, [logits], tf.trainable_variables())

    train_step = optimizer.minimize(loss)
    counter = optimizer.counter
    max_iterations = 50

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      coord = tf.train.Coordinator()
      tf.train.start_queue_runners(sess=sess, coord=coord)
      for iteration in range(max_iterations):
        sess.run([loss, train_step])
        counter_ = sess.run(counter)
        self.assertEqual(counter_, iteration + 1.0)


if __name__ == "__main__":
  tf.disable_v2_behavior()
  tf.test.main()



================================================
FILE: kfac/python/kernel_tests/utils_test.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
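
# The PeriodicInvCovUpdateKfacOpt test above drives the optimizer with
# cov_update_every=1 and invert_every=10.  As a rough sketch of what those
# arguments mean (an assumption based on the argument names, not taken from
# the implementation): covariance statistics are refreshed every
# `cov_update_every` steps and the matrix inverses every `invert_every` steps.
# The helper below is hypothetical and only illustrates such a periodic
# schedule.

def _periodic_update_flags(step, cov_update_every=1, invert_every=10):
  """Returns (update_cov, update_inv) flags for a given global step."""
  return step % cov_update_every == 0, step % invert_every == 0

# With the test's settings every step refreshes the covariances, while only
# every tenth step recomputes the inverses.
assert _periodic_update_flags(3) == (True, False)
assert _periodic_update_flags(20) == (True, True)
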
# ============================================================================== """Tests for kfac.utils.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import numpy as np import tensorflow.compat.v1 as tf from kfac.python.ops import utils class SequenceDictTest(tf.test.TestCase): def testSequenceDictInit(self): seq_dict = utils.SequenceDict() self.assertFalse(seq_dict._dict) def testSequenceDictInitWithIterable(self): reg_dict = {'a': 'foo', 'b': 'bar'} itr = zip(reg_dict.keys(), reg_dict.values()) seq_dict = utils.SequenceDict(itr) self.assertEqual(reg_dict, seq_dict._dict) def testGetItemSingleKey(self): seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) self.assertEqual('foo', seq_dict['a']) def testGetItemMultipleKeys(self): seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')]) def testSetItemSingleKey(self): seq_dict = utils.SequenceDict() seq_dict['a'] = 'foo' self.assertEqual([('a', 'foo')], seq_dict.items()) def testSetItemMultipleKeys(self): seq_dict = utils.SequenceDict() keys = ('a', 'b', 'c') values = ('foo', 'bar', 'baz') seq_dict[keys] = values self.assertItemsEqual(list(zip(keys, values)), seq_dict.items()) class SubGraphTest(tf.test.TestCase): def testBasicGraph(self): a = tf.constant([[1., 2.], [3., 4.]]) b = tf.constant([[5., 6.], [7., 8.]]) c = a + b d = a * b sub_graph = utils.SubGraph((c,)) self.assertTrue(sub_graph.is_member(a)) self.assertTrue(sub_graph.is_member(b)) self.assertTrue(sub_graph.is_member(c)) self.assertFalse(sub_graph.is_member(d)) def testRepeatedAdds(self): a = tf.constant([[1., 2.], [3., 4.]]) b = tf.constant([[5., 6.], [7., 8.]]) c = a + b + a # note that a appears twice in this graph sub_graph = utils.SubGraph((c,)) self.assertTrue(sub_graph.is_member(a)) self.assertTrue(sub_graph.is_member(b)) self.assertTrue(sub_graph.is_member(c)) def testFilterList(self): a = tf.constant([[1., 2.], [3., 4.]]) b = tf.constant([[5., 6.], [7., 8.]]) c = a + b d = a * b sub_graph = utils.SubGraph((c,)) input_list = [b, d] filtered_list = sub_graph.filter_list(input_list) self.assertEqual(filtered_list, [b]) def testVariableUses(self): with tf.Graph().as_default(): var = tf.get_variable('var', shape=[10, 10]) resource_var = tf.get_variable( 'resource_var', shape=[10, 10], use_resource=True) x = tf.zeros([3, 10]) z0 = tf.matmul(x, var) + tf.matmul(x, var) z1 = tf.matmul(x, resource_var) sub_graph = utils.SubGraph((z0, z1)) self.assertEqual(2, sub_graph.variable_uses(var)) self.assertEqual(1, sub_graph.variable_uses(resource_var)) def testVariableUsesRelayOps(self): with tf.Graph().as_default(): a = tf.get_variable("a", shape=[2, 2]) b = tf.get_variable("b", shape=[2, 2]) ai = tf.identity(a) c = tf.matmul(ai, b) d = tf.matmul(ai, b) sub_graph = utils.SubGraph((c, d)) self.assertEqual(2, sub_graph.variable_uses(a)) self.assertEqual(2, sub_graph.variable_uses(b)) class UtilsTest(tf.test.TestCase): def _fully_connected_layer_params(self): weights_part = tf.constant([[1., 2.], [4., 3.]]) bias_part = tf.constant([1., 2.]) return (weights_part, bias_part) def _conv_layer_params(self): weights_shape = 2, 2, 3, 4 biases_shape = weights_shape[-1:] weights = tf.constant(np.random.RandomState(0).randn(*weights_shape)) biases = tf.constant(np.random.RandomState(1).randn(*biases_shape)) return (weights, biases) def testFullyConnectedLayerParamsTupleToMat2d(self): with tf.Graph().as_default(), self.test_session() as sess: 
tf.set_random_seed(200) layer_params = self._fully_connected_layer_params() output = utils.layer_params_to_mat2d(layer_params) self.assertListEqual([3, 2], output.get_shape().as_list()) self.assertAllClose( sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]])) def testFullyConnectedLayerParamsTensorToMat2d(self): with tf.Graph().as_default(), self.test_session() as sess: tf.set_random_seed(200) layer_params = self._fully_connected_layer_params() output = utils.layer_params_to_mat2d(layer_params[0]) self.assertListEqual([2, 2], output.get_shape().as_list()) self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]])) def testConvLayerParamsTupleToMat2d(self): with tf.Graph().as_default(): tf.set_random_seed(200) layer_params = self._conv_layer_params() output = utils.layer_params_to_mat2d(layer_params) self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list()) def testKron(self): with tf.Graph().as_default(), self.test_session() as sess: mat1 = np.array([[1., 2.], [3., 4.]]) mat2 = np.array([[5., 6.], [7., 8.]]) mat1_tf = tf.constant(mat1) mat2_tf = tf.constant(mat2) ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf)) ans_np = np.kron(mat1, mat2) self.assertAllClose(ans_tf, ans_np) def testMat2dToFullyConnectedLayerParamsTuple(self): with tf.Graph().as_default(), self.test_session() as sess: tf.set_random_seed(200) vector_template = self._fully_connected_layer_params() mat2d = tf.constant([[5., 4.], [3., 2.], [1., 0.]]) output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) self.assertIsInstance(output, tuple) self.assertEqual(len(output), 2) a, b = output self.assertAllClose(a, np.array([[5., 4.], [3., 2.]])) self.assertAllClose(b, np.array([1., 0.])) def testMat2dToFullyConnectedLayerParamsTensor(self): with tf.Graph().as_default(), self.test_session() as sess: tf.set_random_seed(200) vector_template = self._fully_connected_layer_params()[0] mat2d = tf.constant([[5., 4.], [3., 2.]]) output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) self.assertAllClose(output, np.array([[5., 4.], [3., 2.]])) def testTensorsToColumn(self): with tf.Graph().as_default(), self.test_session() as sess: tf.set_random_seed(200) vector = tf.constant(np.array([[0., 1.], [2., 3.]])) output = utils.tensors_to_column(vector) self.assertListEqual([4, 1], output.get_shape().as_list()) self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None]) vector = self._fully_connected_layer_params() output = utils.tensors_to_column(vector) self.assertListEqual([6, 1], output.get_shape().as_list()) self.assertAllClose( sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None]) vector = list(vector) vector.append(tf.constant([[6.], [7.], [8.], [9.]])) output = utils.tensors_to_column(vector) self.assertListEqual([10, 1], output.get_shape().as_list()) self.assertAllClose( sess.run(output), np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None]) def testColumnToTensors(self): with tf.Graph().as_default(), self.test_session() as sess: tf.set_random_seed(200) vector_template = tf.constant(np.array([[0., 1.], [2., 3.]])) colvec = tf.constant(np.arange(4.)[:, None]) output = sess.run(utils.column_to_tensors(vector_template, colvec)) self.assertAllClose(output, np.array([[0., 1.], [2., 3.]])) vector_template = self._fully_connected_layer_params() colvec = tf.constant(np.arange(6.)[:, None]) output = sess.run(utils.column_to_tensors(vector_template, colvec)) self.assertIsInstance(output, tuple) self.assertEqual(len(output), 2) a, b = output 
self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) self.assertAllClose(b, np.array([4., 5.])) vector_template = list(vector_template) vector_template.append(tf.constant([[6.], [7.], [8.], [9.]])) colvec = tf.constant(np.arange(10.)[:, None]) output = sess.run(utils.column_to_tensors(vector_template, colvec)) self.assertIsInstance(output, tuple) self.assertEqual(len(output), 3) a, b, c = output self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) self.assertAllClose(b, np.array([4., 5.])) self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]])) def testPosDefInvCholesky(self): with tf.Graph().as_default(), self.test_session() as sess: tf.set_random_seed(200) np.random.seed(0) square = lambda x: np.dot(x, x.T) size = 3 x = square(np.random.randn(size, size)) damp = 0.1 identity = tf.eye(size, dtype=tf.float64) tf_inv = utils.posdef_inv_cholesky(tf.constant(x), identity, damp) np_inv = np.linalg.inv(x + damp * np.eye(size)) self.assertAllClose(sess.run(tf_inv), np_inv) def testPosDefInvMatrixInverse(self): with tf.Graph().as_default(), self.test_session() as sess: tf.set_random_seed(200) np.random.seed(0) square = lambda x: np.dot(x, x.T) size = 3 x = square(np.random.randn(size, size)) damp = 0.1 identity = tf.eye(size, dtype=tf.float64) tf_inv = utils.posdef_inv_matrix_inverse(tf.constant(x), identity, damp) np_inv = np.linalg.inv(x + damp * np.eye(size)) self.assertAllClose(sess.run(tf_inv), np_inv) def testBatchExecute(self): """Ensure batch_execute runs in a round-robin fashion.""" def increment_var(var): return lambda: var.assign_add(1) with tf.Graph().as_default(), self.test_session() as sess: i = tf.get_variable('i', initializer=0) accumulators = [ tf.get_variable('var%d' % j, initializer=0) for j in range(3) ] thunks = [increment_var(var) for var in accumulators] increment_accumulators = utils.batch_execute(i, thunks, 2) increment_i = i.assign_add(1) sess.run(tf.global_variables_initializer()) # Ensure one op per thunk. self.assertEqual(3, len(increment_accumulators)) # Ensure round-robin execution. values = [] for _ in range(5): sess.run(increment_accumulators) sess.run(increment_i) values.append(sess.run(accumulators)) self.assertAllClose( [ [1, 1, 0], # [2, 1, 1], # [2, 2, 2], # [3, 3, 2], # [4, 3, 3] ], values) def testExtractConvolutionPatches(self): with tf.Graph().as_default(), self.test_session() as sess: batch_size = 10 image_spatial_shape = [9, 10, 11] in_channels = out_channels = 32 kernel_spatial_shape = [5, 3, 3] spatial_strides = [1, 2, 1] spatial_dilation = [1, 1, 1] padding = 'SAME' images = tf.random_uniform( [batch_size] + image_spatial_shape + [in_channels], seed=0) kernel_shape = kernel_spatial_shape + [in_channels, out_channels] kernel = tf.random_uniform(kernel_shape, seed=1) # Ensure shape matches expectation. patches = utils.extract_convolution_patches( images, kernel_shape, padding, strides=spatial_strides, dilation_rate=spatial_dilation) result_spatial_shape = ( patches.shape.as_list()[1:1 + len(image_spatial_shape)]) self.assertEqual(patches.shape.as_list(), [batch_size] + result_spatial_shape + kernel_spatial_shape + [in_channels]) # Ensure extract...patches() + matmul() and convolution() implementation # give the same answer. 
outputs = tf.nn.convolution( images, kernel, padding, strides=spatial_strides, dilation_rate=spatial_dilation) patches_flat = tf.reshape( patches, [-1, np.prod(kernel_spatial_shape) * in_channels]) kernel_flat = tf.reshape(kernel, [-1, out_channels]) outputs_flat = tf.matmul(patches_flat, kernel_flat) outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) def testExtractPointwiseConv2dPatches(self): with tf.Graph().as_default(), self.test_session() as sess: batch_size = 10 image_height = image_width = 8 in_channels = out_channels = 3 kernel_height = kernel_width = 1 strides = [1, 1, 1, 1] padding = 'VALID' images = tf.random_uniform( [batch_size, image_height, image_width, in_channels], seed=0) kernel_shape = [kernel_height, kernel_width, in_channels, out_channels] kernel = tf.random_uniform(kernel_shape, seed=1) # Ensure shape matches expectation. patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape) self.assertEqual(patches.shape.as_list(), [ batch_size, image_height, image_width, kernel_height, kernel_width, in_channels ]) # Ensure extract...patches() + matmul() and conv2d() implementation # give the same answer. outputs = tf.nn.conv2d(images, kernel, strides, padding) patches_flat = tf.reshape( patches, [-1, kernel_height * kernel_width * in_channels]) kernel_flat = tf.reshape(kernel, [-1, out_channels]) outputs_flat = tf.matmul(patches_flat, kernel_flat) outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) class AccumulatorVariableTest(tf.test.TestCase): def test_assign_to_var(self): var_shape = (2, 3) acc_var = utils.AccumulatorVariable( name='test_acc_var', dtype=tf.float32, shape=var_shape) values = [ 3. * tf.ones(shape=var_shape), 7. * tf.ones(shape=var_shape), 11. * tf.ones(shape=var_shape) ] acc_ops = [] accc_ops_after_reset = [] for value in values: acc_ops.append(acc_var.accumulate(value)) for value in values[:2]: accc_ops_after_reset.append(acc_var.accumulate(value)) init_op = tf.global_variables_initializer() with self.test_session() as sess: sess.run([init_op]) for acc_op in acc_ops: sess.run(acc_op) acc_var_value = sess.run(acc_var.value) self.assertAllEqual(acc_var_value, 7.*np.ones(shape=var_shape)) sess.run(acc_var.reset()) for acc_op in accc_ops_after_reset: sess.run(acc_op) acc_var_value = sess.run(acc_var.value) self.assertAllEqual(acc_var_value, 5. * np.ones(shape=var_shape)) def test_accumulation(self): var_shape = (2, 3) acc_var = utils.AccumulatorVariable( name='test_acc_var', shape=var_shape, dtype=tf.float32) values = [ 2. * tf.ones(shape=var_shape), 4. * tf.ones(shape=var_shape), 9. * tf.ones(shape=var_shape) ] acc_ops = [] accc_ops_after_reset = [] for value in values: acc_ops.append( acc_var.accumulate(value)) for value in values[:2]: accc_ops_after_reset.append( acc_var.accumulate(value)) init_op = tf.global_variables_initializer() with self.test_session() as sess: sess.run([init_op]) for acc_op in acc_ops: sess.run([acc_op]) acc_var_value = sess.run(acc_var.read_value_and_reset()) self.assertAllEqual(acc_var_value, 5. * np.ones(shape=var_shape)) for acc_op in accc_ops_after_reset: sess.run([acc_op]) acc_var_value = sess.run(acc_var.value) self.assertAllEqual(acc_var_value, 3. 
                                       * np.ones(shape=var_shape))


if __name__ == '__main__':
  tf.disable_v2_behavior()
  tf.test.main()



================================================
FILE: kfac/python/ops/__init__.py
================================================



================================================
FILE: kfac/python/ops/curvature_matrix_vector_products.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Curvature matrix-vector multiplication."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import tensorflow.compat.v1 as tf

from tensorflow.python.util import nest

from kfac.python.ops import utils


class CurvatureMatrixVectorProductComputer(object):
  """Class for computing matrix-vector products for Fishers and GGNs.

  In other words we compute M*v where M is the matrix, v is the vector, and
  * refers to standard matrix/vector multiplication (not element-wise
  multiplication).

  The matrices are defined in terms of some differential quantity of the
  total loss function with respect to a provided list of tensors
  ("wrt_tensors").  For example, the Fisher associated with a log-prob loss
  w.r.t. the parameters.

  The 'vecs' argument to each method is a list of tensors that must have the
  same sizes as the corresponding ones from "wrt_tensors".  They represent
  the vector being multiplied.

  "factors" of the matrix M are defined as matrices B such that B*B^T = M.
  Methods that multiply by the factor B take a 'loss_inner_vecs' argument
  instead of 'vecs', which must be a list of tensors with shapes given by the
  corresponding XXX_inner_shapes property.

  Note that matrix-vector products are not normalized by the batch size, nor
  are any damping terms added to the results.  These things can be easily
  applied externally, if desired.

  See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf
  and https://arxiv.org/abs/1412.1193 for more information about the
  generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector
  products.
  """

  def __init__(self,
               layer_collection,
               wrt_tensors,
               colocate_gradients_with_ops=True):
    """Create a CurvatureMatrixVectorProductComputer object.

    Args:
      layer_collection: A LayerCollection object where the desired loss
        functions are registered (possibly with weighting factors).
      wrt_tensors: A list of Tensors to compute the differential quantities
        (defining the matrices) with respect to. See class description for
        more info.
      colocate_gradients_with_ops: Whether we should request gradients be
        colocated with their respective ops.
(Default: True) """ self._layer_collection = layer_collection self._wrt_tensors = wrt_tensors self._colocate_gradients_with_ops = colocate_gradients_with_ops @property def _loss_colocation_ops(self): return self._layer_collection.loss_colocation_ops @property def _losses(self): return self._layer_collection.losses @property def _inputs_to_losses(self): return list(loss.inputs for loss in self._losses) @property def _inputs_to_losses_flat(self): return nest.flatten(self._inputs_to_losses) @property def _total_loss(self): return self._layer_collection.total_loss() def _get_loss_coeff(self, loss): return self._layer_collection.loss_coeffs[loss] # Jacobian multiplication functions: def _multiply_jacobian(self, vecs): """Multiply vecs by the Jacobian of losses.""" # We stop gradients at wrt_tensors to produce partial derivatives (which is # what we want for Jacobians). jacobian_vecs_flat = utils.fwd_gradients( self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs, stop_gradients=self._wrt_tensors, colocate_gradients_with_ops=self._colocate_gradients_with_ops) return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat) def _multiply_jacobian_transpose(self, loss_vecs): """Multiply vecs by the transpose Jacobian of losses.""" loss_vecs_flat = nest.flatten(loss_vecs) # We stop gradients at wrt_tensors to produce partial derivatives (which is # what we want for Jacobians). return tf.gradients( self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat, stop_gradients=self._wrt_tensors, colocate_gradients_with_ops=self._colocate_gradients_with_ops) # Loss Fisher/GGN multiplication functions: def _multiply_across_losses(self, mult_func, vecs, coeff_mode="regular"): products = [] for loss, vec in zip(self._losses, vecs): with tf.colocate_with(self._loss_colocation_ops[loss]): if coeff_mode == "regular": multiplier = self._get_loss_coeff(loss) elif coeff_mode == "sqrt": multiplier = tf.sqrt(self._get_loss_coeff(loss)) val = mult_func(loss, vec) products.append(tf.cast(multiplier, dtype=val.dtype) * val) return tuple(products) def _multiply_loss_fisher(self, loss_vecs): """Multiply loss_vecs by Fisher of total loss.""" mult_func = lambda loss, vec: loss.multiply_fisher(vec) return self._multiply_across_losses(mult_func, loss_vecs) def _multiply_loss_fisher_factor(self, loss_inner_vecs): """Multiply loss_inner_vecs by factor of Fisher of total loss.""" mult_func = lambda loss, vec: loss.multiply_fisher_factor(vec) return self._multiply_across_losses(mult_func, loss_inner_vecs, coeff_mode="sqrt") def _multiply_loss_fisher_factor_transpose(self, loss_vecs): """Multiply loss_vecs by transpose factor of Fisher of total loss.""" mult_func = lambda loss, vec: loss.multiply_fisher_factor_transpose(vec) return self._multiply_across_losses(mult_func, loss_vecs, coeff_mode="sqrt") def _multiply_loss_ggn(self, loss_vecs): """Multiply loss_vecs by GGN of total loss.""" mult_func = lambda loss, vec: loss.multiply_ggn(vec) return self._multiply_across_losses(mult_func, loss_vecs) def _multiply_loss_ggn_factor(self, loss_inner_vecs): """Multiply loss_inner_vecs by factor of GGN of total loss.""" mult_func = lambda loss, vec: loss.multiply_ggn_factor(vec) return self._multiply_across_losses(mult_func, loss_inner_vecs, coeff_mode="sqrt") def _multiply_loss_ggn_factor_transpose(self, loss_vecs): """Multiply loss_vecs by transpose factor of GGN of total loss.""" mult_func = lambda loss, vec: loss.multiply_ggn_factor_transpose(vec) return self._multiply_across_losses(mult_func, loss_vecs, 
coeff_mode="sqrt") # Matrix-vector product functions (users should directly call these): def multiply_fisher(self, vecs): """Multiply vecs by Fisher of total loss.""" jacobian_vecs = self._multiply_jacobian(vecs) loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs) return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs) def multiply_fisher_factor_transpose(self, vecs): """Multiply vecs by transpose of factor of Fisher of total loss.""" jacobian_vecs = self._multiply_jacobian(vecs) return self._multiply_loss_fisher_factor_transpose(jacobian_vecs) def multiply_fisher_factor(self, loss_inner_vecs): """Multiply loss_inner_vecs by factor of Fisher of total loss.""" fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor( loss_inner_vecs) return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs) def multiply_hessian(self, vecs): """Multiply vecs by Hessian of total loss.""" return tf.gradients( tf.gradients( self._total_loss, self._wrt_tensors, colocate_gradients_with_ops=self._colocate_gradients_with_ops), self._wrt_tensors, grad_ys=vecs, colocate_gradients_with_ops=self._colocate_gradients_with_ops) def multiply_ggn(self, vecs): """Multiply vecs by generalized Gauss-Newton of total loss.""" jacobian_vecs = self._multiply_jacobian(vecs) loss_ggn_jacobian_vecs = self._multiply_loss_ggn(jacobian_vecs) return self._multiply_jacobian_transpose(loss_ggn_jacobian_vecs) def multiply_ggn_factor_transpose(self, vecs): """Multiply vecs by transpose of factor of GGN of total loss.""" jacobian_vecs = self._multiply_jacobian(vecs) return self._multiply_loss_ggn_factor_transpose(jacobian_vecs) def multiply_ggn_factor(self, loss_inner_vecs): """Multiply loss_inner_vecs by factor of GGN of total loss.""" ggn_factor_transpose_vecs = ( self._multiply_loss_ggn_factor(loss_inner_vecs)) return self._multiply_jacobian_transpose(ggn_factor_transpose_vecs) # Shape properties for multiply_XXX_factor methods: @property def fisher_factor_inner_shapes(self): """Shapes required by multiply_fisher_factor.""" return tuple(loss.fisher_factor_inner_shape for loss in self._losses) @property def fisher_factor_inner_static_shapes(self): """Shapes required by multiply_fisher_factor.""" return tuple(loss.fisher_factor_inner_static_shape for loss in self._losses) @property def ggn_factor_inner_shapes(self): """Shapes required by multiply_generalized_gauss_newton_factor.""" return tuple(loss.ggn_factor_inner_shape for loss in self._losses) @property def ggn_factor_inner_static_shapes(self): """Shapes required by multiply_generalized_gauss_newton_factor.""" return tuple(loss.ggn_factor_inner_static_shape for loss in self._losses) ================================================ FILE: kfac/python/ops/estimator.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
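
# A brief usage sketch for the CurvatureMatrixVectorProductComputer defined
# in curvature_matrix_vector_products.py above: given a LayerCollection with
# a registered loss and a list of parameter tensors, it builds ops that
# multiply arbitrary vectors by the Fisher (or GGN/Hessian) of the total
# loss.  This is only a sketch; the toy model and the variable names are made
# up, and it assumes a standard TF1 graph-mode setup like the tests in this
# repository.

import tensorflow.compat.v1 as tf

from kfac.python.ops import curvature_matrix_vector_products as cmvp
from kfac.python.ops import layer_collection as lc

with tf.Graph().as_default():
  x = tf.random_normal((8, 4))
  w = tf.get_variable("w_example", shape=(4, 2))
  logits = tf.matmul(x, w)

  layers = lc.LayerCollection()
  layers.register_categorical_predictive_distribution(logits)

  computer = cmvp.CurvatureMatrixVectorProductComputer(layers, [w])

  # Multiply a vector with the same shape as `w` by the Fisher of the loss.
  # multiply_ggn and multiply_hessian follow the same calling convention.
  fisher_vec = computer.multiply_fisher([tf.ones_like(w)])
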
# ============================================================================== """Defines the high-level Fisher estimator class.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import abc # Dependency imports import numpy as np import six import tensorflow.compat.v1 as tf from tensorflow.python.util import nest from kfac.python.ops import placement from kfac.python.ops import utils # The linter is confused. # pylint: disable=abstract-class-instantiated def make_fisher_estimator(placement_strategy=None, **kwargs): """Creates Fisher estimator instances based on the placement strategy. For example if the `placement_strategy` is 'round_robin' then `FisherEstimatorRoundRobin` instance is returned. Args: placement_strategy: `string`, Strategy to be used for placing covariance variables, covariance ops and inverse ops. Check `placement.FisherEstimatorRoundRobin` for a concrete example. **kwargs: Arguments to be passed into `FisherEstimator` class initializer. Returns: An instance of class which inherits from `FisherEstimator` and the mixin which implements specific placement strategy. See, `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and `RoundRobinPlacementMixin`, as an example. Raises: ValueError: If the `placement_strategy` argument is not one of the recognized options. """ if placement_strategy in [None, "round_robin"]: return FisherEstimatorRoundRobin(**kwargs) elif placement_strategy == "replica_round_robin": return FisherEstimatorReplicaRoundRobin(**kwargs) else: raise ValueError( "Unimplemented vars and ops placement strategy : {}".format( placement_strategy)) # pylint: enable=abstract-class-instantiated @six.add_metaclass(abc.ABCMeta) class FisherEstimator(object): """Fisher estimator class supporting various approximations of the Fisher. This is an abstract base class which does not implement a strategy for placing covariance variables, covariance update ops and inverse update ops. The placement strategies are implemented in `placement.py`. See `FisherEstimatorRoundRobin` for example of a concrete subclass with a round-robin placement strategy. """ def __init__(self, variables, cov_ema_decay, damping, layer_collection, exps=(-1,), estimation_mode="gradients", colocate_gradients_with_ops=True, name="FisherEstimator", compute_cholesky=False, compute_cholesky_inverse=False, compute_params_stats=False, batch_size=None): """Create a FisherEstimator object. Args: variables: A `list` of variables for which to estimate the Fisher. This must match the variables registered in layer_collection (if it is not None). cov_ema_decay: The decay factor used when calculating the covariance estimate moving averages. damping: float or 0D Tensor. This quantity times the identity matrix is (approximately) added to the matrix being estimated. layer_collection: A LayerCollection object which holds for the Fisher blocks, Kronecker factors, and losses associated with the graph. exps: List of floats or ints. These represent the different matrix powers of the approximate Fisher that the FisherEstimator will be able to multiply vectors by. If the user asks for a matrix power other one of these (or 1, which is always supported), there will be a failure. (Default: (-1,)) estimation_mode: The type of estimator to use for the Fishers/GGNs. Can be 'gradients', 'empirical', 'curvature_prop', 'curvature_prop_GGN', 'exact', or 'exact_GGN'. (Default: 'gradients'). 'gradients' is the basic estimation approach from the original K-FAC paper. 
'empirical' computes the 'empirical' Fisher information matrix (which uses the data's distribution for the targets, as opposed to the true Fisher which uses the model's distribution) and requires that each registered loss have specified targets. 'curvature_propagation' is a method which estimates the Fisher using self-products of random 1/-1 vectors times "half-factors" of the Fisher, as described here: https://arxiv.org/abs/1206.6464 . 'exact' is the obvious generalization of Curvature Propagation to compute the exact Fisher (modulo any additional diagonal or Kronecker approximations) by looping over one-hot vectors for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking. Finally, 'curvature_prop_GGN' and 'exact_GGN' are analogous to 'curvature_prop' and 'exact', but estimate the Generalized Gauss-Newton matrix (GGN). colocate_gradients_with_ops: Whether we should request gradients be colocated with their respective ops. (Default: True) name: A string. A name given to this estimator, which is added to the variable scope when constructing variables and ops. (Default: "FisherEstimator") compute_cholesky: Bool. Whether or not the FisherEstimator will be able to multiply vectors by the Cholesky factor. (Default: False) compute_cholesky_inverse: Bool. Whether or not the FisherEstimator will be able to multiply vectors by the Cholesky factor inverse. (Default: False) compute_params_stats: Bool. If True, we compute the first order version of the statistics computed to estimate the Fisher/GGN. These correspond to the `variables` method in a one-to-one fashion. They are available via the `params_stats` property. When estimation_mode is 'empirical', this will correspond to the standard parameter gradient on the loss. (Default: False) batch_size: The size of the mini-batch. Only needed when `compute_params_stats` is True. Note that when using data parallelism where the model graph and optimizer are replicated across multiple devices, this should be the per-replica batch size. An example of this is sharded data on the TPU, where batch_size should be set to the total batch size divided by the number of shards. (Default: None) Raises: ValueError: If no losses have been registered with layer_collection. 
""" self._variables = variables self._cov_ema_decay = cov_ema_decay self._damping = damping self._estimation_mode = estimation_mode self._layer_collection = layer_collection self._gradient_fns = { "gradients": self._get_grads_lists_gradients, "empirical": self._get_grads_lists_empirical, "curvature_prop": self._get_grads_lists_curvature_prop, "curvature_prop_GGN": self._get_grads_lists_curvature_prop, "exact": self._get_grads_lists_exact, "exact_GGN": self._get_grads_lists_exact } self._mat_type_table = { "gradients": "Fisher", "empirical": "Empirical_Fisher", "curvature_prop": "Fisher", "curvature_prop_GGN": "GGN", "exact": "Fisher", "exact_GGN": "GGN", } self._colocate_gradients_with_ops = colocate_gradients_with_ops self._exps = exps self._compute_cholesky = compute_cholesky self._compute_cholesky_inverse = compute_cholesky_inverse self._name = name self._compute_params_stats = compute_params_stats self._batch_size = batch_size if compute_params_stats and batch_size is None: raise ValueError("Batch size needs to be passed in when " "compute_params_stats is True.") self._finalized = False @property def variables(self): return self._variables @property def damping(self): return self._damping @property def blocks(self): """All registered FisherBlocks.""" return self.layers.get_blocks() @property def factors(self): """All registered FisherFactors.""" return self.layers.get_factors() @property def name(self): return self._name @property def layers(self): return self._layer_collection @property def mat_type(self): return self._mat_type_table[self._estimation_mode] @property def params_stats(self): return self._params_stats @abc.abstractmethod def _place_and_compute_transformation_thunks(self, thunks, params_list): """Computes transformation thunks with device placement. Device placement will be determined by the strategy asked for when this estimator was constructed. Args: thunks: A list of thunks to run. Must be in one to one correspondence with the `blocks` property. params_list: A list of the corresponding parameters. Must be in one to one correspondence with the `blocks` property. Returns: A list (in the same order) of the returned results of the thunks, with possible device placement applied. """ pass def _compute_transformation(self, vecs_and_vars, transform): """Computes a block-wise transformation of a list of vectors. Args: vecs_and_vars: List of (vector, variable) pairs. transform: A function of the form f(fb, vec), that returns the transformed vector, where vec is the vector to transform and fb is its corresponding block. Returns: A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars) def make_thunk(fb, params): return lambda: transform(fb, vecs[params]) thunks = tuple(make_thunk(fb, params) for params, fb in self.layers.fisher_blocks.items()) params_list = tuple(params for params, _ in self.layers.fisher_blocks.items()) results = self._place_and_compute_transformation_thunks(thunks, params_list) trans_vecs = utils.SequenceDict() for params, result in zip(self.layers.fisher_blocks, results): trans_vecs[params] = result return [(trans_vecs[var], var) for _, var in vecs_and_vars] def multiply_inverse(self, vecs_and_vars): """Multiplies the vecs by the corresponding (damped) inverses of the blocks. Args: vecs_and_vars: List of (vector, variable) pairs. Returns: A list of (transformed vector, var) pairs in the same order as vecs_and_vars. 
""" return self.multiply_matpower(-1, vecs_and_vars) def multiply(self, vecs_and_vars): """Multiplies the vectors by the corresponding (damped) blocks. Args: vecs_and_vars: List of (vector, variable) pairs. Returns: A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ return self.multiply_matpower(1, vecs_and_vars) def multiply_matpower(self, exp, vecs_and_vars): """Multiplies the vecs by the corresponding matrix powers of the blocks. Args: exp: A float representing the power to raise the blocks by before multiplying it by the vector. vecs_and_vars: List of (vector, variable) pairs. Returns: A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ fcn = lambda fb, vec: fb.multiply_matpower(vec, exp) return self._compute_transformation(vecs_and_vars, fcn) def multiply_cholesky(self, vecs_and_vars, transpose=False): """Multiplies the vecs by the corresponding Cholesky factors. Args: vecs_and_vars: List of (vector, variable) pairs. transpose: Bool. If true the Cholesky factors are transposed before multiplying the vecs. (Default: False) Returns: A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose) return self._compute_transformation(vecs_and_vars, fcn) def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False): """Mults the vecs by the inverses of the corresponding Cholesky factors. Note: if you are using Cholesky inverse multiplication to sample from a matrix-variate Gaussian you will want to multiply by the transpose. Let L be the Cholesky factor of F and observe that L^-T * L^-1 = (L * L^T)^-1 = F^-1 . Thus we want to multiply by L^-T in order to sample from Gaussian with covariance F^-1. Args: vecs_and_vars: List of (vector, variable) pairs. transpose: Bool. If true the Cholesky factor inverses are transposed before multiplying the vecs. (Default: False) Returns: A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose) return self._compute_transformation(vecs_and_vars, fcn) def _instantiate_factors(self): """Instantiates FisherFactors' variables. Raises: ValueError: If estimation_mode was improperly specified at construction. """ blocks = self.blocks tensors_to_compute_grads = [ block.tensors_to_compute_grads() for block in blocks ] if self._compute_params_stats: tensors_to_compute_grads = tensors_to_compute_grads + self.variables try: grads_lists = self._gradient_fns[self._estimation_mode]( tensors_to_compute_grads) except KeyError: raise ValueError("Unrecognized value {} for estimation_mode.".format( self._estimation_mode)) if any(grad is None for grad in nest.flatten(grads_lists)): tensors_flat = nest.flatten(tensors_to_compute_grads) grads_flat = nest.flatten(grads_lists) bad_tensors = tuple( tensor for tensor, grad in zip(tensors_flat, grads_flat) if grad is None) bad_string = "" for tensor in bad_tensors: bad_string += "\t{}\n".format(tensor) raise ValueError("It looks like you registered one of more tensors that " "the registered loss/losses don't depend on. (These " "returned None from tf.gradients.) The tensors were:" "\n\n" + bad_string) if self._compute_params_stats: idx = len(blocks) params_stats_unnorm = tuple(tf.add_n(grad_list) for grad_list in grads_lists[idx:]) scalar = 1. 
/ tf.cast(self._batch_size, dtype=params_stats_unnorm[0].dtype) params_stats = utils.sprod(scalar, params_stats_unnorm) # batch_size should be the per-replica batch size and thus we do a # cross-replica mean instead of a sum here self._params_stats = tuple(utils.all_average(tensor) for tensor in params_stats) grads_lists = grads_lists[:idx] for grads_list, block in zip(grads_lists, blocks): block.instantiate_factors(grads_list, self.damping) def _register_matrix_functions(self): for block in self.blocks: for exp in self._exps: block.register_matpower(exp) if self._compute_cholesky: block.register_cholesky() if self._compute_cholesky_inverse: block.register_cholesky_inverse() def _finalize(self): if not self._finalized: self.layers.finalize() self.layers.check_registration(self.variables) self._instantiate_factors() self._register_matrix_functions() self._finalized = True def _check_batch_sizes(self, factor): """Checks that the batch size(s) for a factor matches the reference value.""" # Should make these messages use quote characters instead of parentheses # when the bug with quote character rendering in assertion messages is # fixed. See b/129476712 if self._batch_size is None: batch_size = self.factors[0].batch_size() string = ("Batch size {} for factor (" + factor.name + ") of type " + utils.cls_name(factor) + " did not match value {} used by " "factor (" + self.factors[0].name + ") of type " + utils.cls_name(self.factors[0]) + ".") else: batch_size = self._batch_size string = ("Batch size {} for factor (" + factor.name + ") of type " + utils.cls_name(factor) + " did not match value {} which was " "passed to optimizer/estimator.") factor_batch_size = factor.batch_size() if isinstance(batch_size, int) and isinstance(factor_batch_size, int): if factor_batch_size != batch_size: raise ValueError(string.format(factor_batch_size, batch_size)) return factor.check_partial_batch_sizes() else: message = string.format("(x)", "(y)") with tf.control_dependencies([factor.check_partial_batch_sizes()]): return tf.assert_equal(factor_batch_size, batch_size, message=message) def _create_ops_and_vars_thunks(self, scope=None): """Create thunks that make the ops and vars on demand. This function returns 4 lists of thunks: cov_variable_thunks, cov_update_thunks, inv_variable_thunks, and inv_update_thunks. The length of each list is the number of factors and the i-th element of each list corresponds to the i-th factor (given by the "factors" property). Note that the execution of these thunks must happen in a certain partial order. The i-th element of cov_variable_thunks must execute before the i-th element of cov_update_thunks (and also the i-th element of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks must execute before the i-th element of inv_update_thunks. TL;DR (oversimplified): Execute the thunks according to the order that they are returned. Args: scope: A string or None. If None it will be set to the name of this estimator (given by the name property). All thunks will execute inside of a variable scope of the given name. (Default: None) Returns: cov_variable_thunks: A list of thunks that make the cov variables. cov_update_thunks: A list of thunks that make the cov update ops. inv_variable_thunks: A list of thunks that make the inv variables. inv_update_thunks: A list of thunks that make the inv update ops. 
""" self._finalize() scope = self.name if scope is None else scope cov_variable_thunks = [ self._create_cov_variable_thunk(factor, scope) for factor in self.factors ] cov_update_thunks = [ self._create_cov_update_thunk(factor, scope) for factor in self.factors ] inv_variable_thunks = [ self._create_inv_variable_thunk(factor, scope) for factor in self.factors ] inv_update_thunks = [ self._create_inv_update_thunk(factor, scope) for factor in self.factors ] return (cov_variable_thunks, cov_update_thunks, inv_variable_thunks, inv_update_thunks) @abc.abstractmethod def create_ops_and_vars_thunks(self, scope=None): """Create thunks that make the ops and vars on demand with device placement. This function returns 4 lists of thunks: cov_variable_thunks, cov_update_thunks, inv_variable_thunks, and inv_update_thunks. The length of each list is the number of factors and the i-th element of each list corresponds to the i-th factor (given by the "factors" property). Note that the execution of these thunks must happen in a certain partial order. The i-th element of cov_variable_thunks must execute before the i-th element of cov_update_thunks (and also the i-th element of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks must execute before the i-th element of inv_update_thunks. TL;DR (oversimplified): Execute the thunks according to the order that they are returned. Device placement will be determined by the strategy asked for when this estimator was constructed. Args: scope: A string or None. If None it will be set to the name of this estimator (given by the name property). All thunks will execute inside of a variable scope of the given name. (Default: None) Returns: cov_variable_thunks: A list of thunks that make the cov variables. cov_update_thunks: A list of thunks that make the cov update ops. inv_variable_thunks: A list of thunks that make the inv variables. inv_update_thunks: A list of thunks that make the inv update ops. """ pass def make_vars_and_create_op_thunks(self, scope=None): """Make vars and create op thunks with device placement. Similar to create_ops_and_vars_thunks but actually makes the variables instead of returning thunks that make them. Device placement will be determined by the strategy asked for when this estimator was constructed. Args: scope: A string or None. If None it will be set to the name of this estimator (given by the name property). All variables will be created, and all thunks will execute, inside of a variable scope of the given name. (Default: None) Returns: cov_update_thunks: List of cov update thunks. Corresponds one-to-one with the list of factors given by the "factors" property. inv_update_thunks: List of inv update thunks. Corresponds one-to-one with the list of factors given by the "factors" property. """ (cov_variable_thunks, cov_update_thunks, inv_variable_thunks, inv_update_thunks) = self.create_ops_and_vars_thunks(scope=scope) for thunk in cov_variable_thunks: thunk() for thunk in inv_variable_thunks: thunk() return cov_update_thunks, inv_update_thunks def get_cov_vars(self): """Returns all covariance variables associated with each Fisher factor. Note the returned list also includes additional factor specific covariance variables. Returns: List of list. The number of inner lists is equal to number of factors. And each inner list contains all covariance variables for that factor. 
""" return tuple(factor.get_cov_vars() for factor in self.factors) def get_inv_vars(self): """Returns all covariance variables associated with each Fisher factor. Note the returned list also includes additional factor specific covariance variables. Returns: List of list. The number of inner lists is equal to number of factors. And each inner list contains all inverse computation related variables for that factor. """ return tuple(factor.get_inv_vars() for factor in self.factors) def _create_cov_variable_thunk(self, factor, scope): """Constructs a covariance variable thunk for a single FisherFactor.""" def thunk(): with tf.variable_scope(scope): return factor.instantiate_cov_variables() return thunk def _create_cov_update_thunk(self, factor, scope): """Constructs a covariance update thunk for a single FisherFactor.""" def thunk(should_decay=True): if isinstance(should_decay, bool): ema_decay = self._cov_ema_decay if should_decay else 1.0 else: ema_decay = tf.cond(should_decay, lambda: self._cov_ema_decay, lambda: 1.0) ema_weight = 1.0 with tf.variable_scope(scope): with tf.control_dependencies([self._check_batch_sizes(factor)]): return factor.make_covariance_update_op(ema_decay, ema_weight) return thunk def _create_inv_variable_thunk(self, factor, scope): """Constructs a inverse variable thunk for a single FisherFactor.""" def thunk(): with tf.variable_scope(scope): return factor.instantiate_inv_variables() return thunk def _create_inv_update_thunk(self, factor, scope): """Constructs an inverse update thunk for a single FisherFactor.""" def thunk(): with tf.variable_scope(scope): return tf.group(factor.make_inverse_update_ops()) return thunk def _get_grads_lists_gradients(self, tensors): # Passing in a list of loss values is better than passing in the sum as # the latter creates unnecessary ops on the default device grads_flat = tf.gradients( self.layers.eval_losses(target_mode="sample", coeff_mode="sqrt"), nest.flatten(tensors), colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_grads_lists_empirical(self, tensors): # Passing in a list of loss values is better than passing in the sum as # the latter creates unnessesary ops on the default device grads_flat = tf.gradients( self.layers.eval_losses(target_mode="data", coeff_mode="regular"), nest.flatten(tensors), colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_transformed_random_signs(self): if self.mat_type == "Fisher": mult_func = lambda loss, index: loss.multiply_fisher_factor(index) inner_shape_func = lambda loss: loss.fisher_factor_inner_shape elif self.mat_type == "GGN": mult_func = lambda loss, index: loss.multiply_ggn_factor(index) inner_shape_func = lambda loss: loss.ggn_factor_inner_shape transformed_random_signs = [] for loss in self.layers.losses: with tf.colocate_with(self.layers.loss_colocation_ops[loss]): value = mult_func(loss, utils.generate_random_signs(inner_shape_func(loss), dtype=loss.dtype)) coeff = tf.cast(self.layers.loss_coeffs[loss], dtype=value.dtype) transformed_random_signs.append(tf.sqrt(coeff) * value) return transformed_random_signs def _get_grads_lists_curvature_prop(self, tensors): loss_inputs = list(loss.inputs for loss in self.layers.losses) transformed_random_signs = self._get_transformed_random_signs() grads_flat = tf.gradients( nest.flatten(loss_inputs), 
nest.flatten(tensors), grad_ys=nest.flatten(transformed_random_signs), colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_grads_lists_exact(self, tensors): if self.mat_type == "Fisher": # pylint: disable=g-long-lambda mult_func = (lambda loss, index: loss.multiply_fisher_factor_replicated_one_hot(index)) inner_shape_func = lambda loss: loss.fisher_factor_inner_static_shape elif self.mat_type == "GGN": # pylint: disable=g-long-lambda mult_func = (lambda loss, index: loss.multiply_ggn_factor_replicated_one_hot(index)) inner_shape_func = lambda loss: loss.fisher_ggn_inner_static_shape # Loop over all coordinates of all losses. grads_all = [] for loss in self.layers.losses: with tf.colocate_with(self.layers.loss_colocation_ops[loss]): for index in np.ndindex(*inner_shape_func(loss)[1:]): value = mult_func(loss, index) coeff = tf.cast(self.layers.loss_coeffs[loss], dtype=value.dtype) transformed_one_hot = tf.sqrt(coeff) * value grads_flat = tf.gradients( loss.inputs, nest.flatten(tensors), grad_ys=transformed_one_hot, colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) return tuple(zip(*grads_all)) class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin, FisherEstimator): """FisherEstimator which provides round robin device placement strategy.""" pass class FisherEstimatorReplicaRoundRobin( placement.ReplicaRoundRobinPlacementMixin, FisherEstimator): """FisherEstimator which provides round robin replica placement strategy.""" pass ================================================ FILE: kfac/python/ops/fisher_blocks.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """FisherBlock definitions. This library contains classes for estimating blocks in a model's Fisher Information matrix. Suppose one has a model that parameterizes a posterior distribution over 'y' given 'x' with parameters 'params', p(y | x, params). 
Its Fisher Information matrix is given by, F(params) = E[ v(x, y, params) v(x, y, params)^T ] where, v(x, y, params) = (d / d params) log p(y | x, params) and the expectation is taken with respect to the data's distribution for 'x' and the model's posterior distribution for 'y', x ~ p(x) y ~ p(y | x, params) """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import abc # Dependency imports import six import tensorflow.compat.v1 as tf from tensorflow.python.util import nest from kfac.python.ops import fisher_factors from kfac.python.ops import utils # For blocks corresponding to convolutional layers, or any type of block where # the parameters can be thought of as being replicated in time or space, # we want to adjust the scale of the damping by # damping /= num_replications ** NORMALIZE_DAMPING_POWER NORMALIZE_DAMPING_POWER = 1.0 # Methods for adjusting damping for FisherBlocks. See # compute_pi_adjusted_damping() for details. PI_OFF_NAME = "off" PI_TRACENORM_NAME = "tracenorm" PI_TYPE = PI_TRACENORM_NAME def set_global_constants(normalize_damping_power=None, pi_type=None): """Sets various global constants used by the classes in this module.""" global NORMALIZE_DAMPING_POWER global PI_TYPE if normalize_damping_power is not None: NORMALIZE_DAMPING_POWER = normalize_damping_power if pi_type is not None: PI_TYPE = pi_type def normalize_damping(damping, num_replications): """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER.""" if NORMALIZE_DAMPING_POWER: return damping / (num_replications ** NORMALIZE_DAMPING_POWER) return damping def compute_pi_tracenorm(left_cov, right_cov): """Computes the scalar constant pi for Tikhonov regularization/damping. pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. Args: left_cov: A LinearOperator object. The left Kronecker factor "covariance". right_cov: A LinearOperator object. The right Kronecker factor "covariance". Returns: The computed scalar constant pi for these Kronecker Factors (as a Tensor). """ # Instead of dividing by the dim of the norm, we multiply by the dim of the # other norm. This works out the same in the ratio. left_norm = left_cov.trace() * int(right_cov.domain_dimension) right_norm = right_cov.trace() * int(left_cov.domain_dimension) def normal_case(): assert_positive = tf.assert_positive( right_norm, message="PI computation, trace of right cov matrix should be positive. " "Note that most likely cause of this error is that the optimizer " "diverged (e.g. due to bad hyperparameters).") with tf.control_dependencies([assert_positive]): return tf.sqrt(left_norm / right_norm) def zero_case(): return tf.constant(1.0, dtype=left_norm.dtype) return tf.cond(tf.equal(left_norm * right_norm, 0.0), zero_case, normal_case) def compute_pi_adjusted_damping(left_cov, right_cov, damping): if PI_TYPE == PI_TRACENORM_NAME: pi = compute_pi_tracenorm(left_cov, right_cov) damping = tf.cast(damping, dtype=pi.dtype) return (damping * pi, damping / pi) elif PI_TYPE == PI_OFF_NAME: return (damping, damping) class PackagedFunc(object): """A Python thunk with a stable ID. Enables stable names for lambdas. """ def __init__(self, func, func_id): """Initializes PackagedFunc. Args: func: a zero-arg Python function. func_id: a hashable, function that produces a hashable, or a list/tuple thereof. 
""" self._func = func func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,) self._func_id = func_id def __call__(self): return self._func() @property def func_id(self): """A hashable identifier for this function.""" return tuple(elt() if callable(elt) else elt for elt in self._func_id) def _package_func(func, func_id): return PackagedFunc(func, func_id) @six.add_metaclass(abc.ABCMeta) class FisherBlock(object): """Abstract base class for objects modeling approximate Fisher matrix blocks. Subclasses must implement register_matpower, multiply_matpower, instantiate_factors, tensors_to_compute_grads, and num_registered_towers methods. """ def __init__(self, layer_collection): self._layer_collection = layer_collection @abc.abstractmethod def instantiate_factors(self, grads_list, damping): """Creates and registers the component factors of this Fisher block. Args: grads_list: A list gradients (each a Tensor or tuple of Tensors) with respect to the tensors returned by tensors_to_compute_grads() that are to be used to estimate the block. damping: The damping factor (float or Tensor). """ pass @abc.abstractmethod def register_matpower(self, exp): """Registers a matrix power to be computed by the block. Args: exp: A float representing the power to raise the block by. """ pass @abc.abstractmethod def register_cholesky(self): """Registers a Cholesky factor to be computed by the block.""" pass @abc.abstractmethod def register_cholesky_inverse(self): """Registers an inverse Cholesky factor to be computed by the block.""" pass def register_inverse(self): """Registers a matrix inverse to be computed by the block.""" self.register_matpower(-1) @abc.abstractmethod def multiply_matpower(self, vector, exp): """Multiplies the vector by the (damped) matrix-power of the block. Args: vector: The vector (a Tensor or tuple of Tensors) to be multiplied. exp: A float representing the power to raise the block by before multiplying it by the vector. Returns: The vector left-multiplied by the (damped) matrix-power of the block. """ pass def multiply_inverse(self, vector): """Multiplies the vector by the (damped) inverse of the block. Args: vector: The vector (a Tensor or tuple of Tensors) to be multiplied. Returns: The vector left-multiplied by the (damped) inverse of the block. """ return self.multiply_matpower(vector, -1) def multiply(self, vector): """Multiplies the vector by the (damped) block. Args: vector: The vector (a Tensor or tuple of Tensors) to be multiplied. Returns: The vector left-multiplied by the (damped) block. """ return self.multiply_matpower(vector, 1) @abc.abstractmethod def multiply_cholesky(self, vector, transpose=False): """Multiplies the vector by the (damped) Cholesky-factor of the block. Args: vector: The vector (a Tensor or tuple of Tensors) to be multiplied. transpose: Bool. If true the Cholesky factor is transposed before multiplying the vector. (Default: False) Returns: The vector left-multiplied by the (damped) Cholesky-factor of the block. """ pass @abc.abstractmethod def multiply_cholesky_inverse(self, vector, transpose=False): """Multiplies vector by the (damped) inverse Cholesky-factor of the block. Args: vector: The vector (a Tensor or tuple of Tensors) to be multiplied. transpose: Bool. If true the Cholesky factor inverse is transposed before multiplying the vector. (Default: False) Returns: Vector left-multiplied by (damped) inverse Cholesky-factor of the block. 
""" pass @abc.abstractmethod def tensors_to_compute_grads(self): """Returns the Tensor(s) with respect to which this FisherBlock needs grads. """ pass @abc.abstractproperty def num_registered_towers(self): """Number of towers registered for this FisherBlock. Typically equal to the number of towers in a multi-tower setup. """ pass @six.add_metaclass(abc.ABCMeta) class FullFB(FisherBlock): """Base class for blocks that use full matrix representations (no approx).""" def register_matpower(self, exp): self._factor.register_matpower(exp, self._damping_func) def register_cholesky(self): self._factor.register_cholesky(self._damping_func) def register_cholesky_inverse(self): self._factor.register_cholesky_inverse(self._damping_func) def _multiply_matrix(self, matrix, vector, transpose=False): vector_flat = utils.tensors_to_column(vector) out_flat = matrix.matmul(vector_flat, adjoint=transpose) return utils.column_to_tensors(vector, out_flat) def multiply_matpower(self, vector, exp): matrix = self._factor.get_matpower(exp, self._damping_func) return self._multiply_matrix(matrix, vector) def multiply_cholesky(self, vector, transpose=False): matrix = self._factor.get_cholesky(self._damping_func) return self._multiply_matrix(matrix, vector, transpose=transpose) def multiply_cholesky_inverse(self, vector, transpose=False): matrix = self._factor.get_cholesky_inverse(self._damping_func) return self._multiply_matrix(matrix, vector, transpose=transpose) def full_fisher_block(self): """Explicitly constructs the full Fisher block.""" return self._factor.get_cov_as_linear_operator().to_dense() class NaiveFullFB(FullFB): """FisherBlock using a full matrix estimate (no approximations). NaiveFullFB uses a full matrix estimate (no approximations), and should only ever be used for very low dimensional parameters. Note that this uses the naive "square the sum estimator", and so is applicable to any type of parameter in principle, but has very high variance. """ def __init__(self, layer_collection, params): """Creates a NaiveFullFB block. Args: layer_collection: The LayerCollection object which owns this block. params: The parameters of this layer (Tensor or tuple of Tensors). """ self._batch_sizes = [] self._params = params super(NaiveFullFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): self._damping_func = _package_func(lambda: damping, (damping,)) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.NaiveFullFactor, (grads_list, self._batch_size)) def tensors_to_compute_grads(self): return self._params def register_additional_tower(self, batch_size): """Register an additional tower. Args: batch_size: The batch size, used in the covariance estimator. """ self._batch_sizes.append(batch_size) @property def num_registered_towers(self): return len(self._batch_sizes) @property def _batch_size(self): return tf.reduce_sum(self._batch_sizes) @six.add_metaclass(abc.ABCMeta) class DiagonalFB(FisherBlock): """A base class for FisherBlocks that use diagonal approximations.""" def register_matpower(self, exp): # Not needed for this. Matrix powers are computed on demand in the # diagonal case pass def register_cholesky(self): # Not needed for this. Cholesky's are computed on demand in the # diagonal case pass def register_cholesky_inverse(self): # Not needed for this. 
Cholesky inverses's are computed on demand in the # diagonal case pass def _multiply_matrix(self, matrix, vector): vector_flat = utils.tensors_to_column(vector) out_flat = matrix.matmul(vector_flat) return utils.column_to_tensors(vector, out_flat) def multiply_matpower(self, vector, exp): matrix = self._factor.get_matpower(exp, self._damping_func) return self._multiply_matrix(matrix, vector) def multiply_cholesky(self, vector, transpose=False): matrix = self._factor.get_cholesky(self._damping_func) return self._multiply_matrix(matrix, vector) def multiply_cholesky_inverse(self, vector, transpose=False): matrix = self._factor.get_cholesky_inverse(self._damping_func) return self._multiply_matrix(matrix, vector) def full_fisher_block(self): return self._factor.get_cov_as_linear_operator().to_dense() class NaiveDiagonalFB(DiagonalFB): """FisherBlock using a diagonal matrix approximation. This type of approximation is generically applicable but quite primitive. Note that this uses the naive "square the sum estimator", and so is applicable to any type of parameter in principle, but has very high variance. """ def __init__(self, layer_collection, params): """Creates a NaiveDiagonalFB block. Args: layer_collection: The LayerCollection object which owns this block. params: The parameters of this layer (must be a single Tensor). """ self._params = params self._batch_sizes = [] super(NaiveDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): self._damping_func = _package_func(lambda: damping, (damping,)) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) def tensors_to_compute_grads(self): return self._params def register_additional_tower(self, batch_size): """Register an additional tower. Args: batch_size: The batch size, used in the covariance estimator. """ self._batch_sizes.append(batch_size) @property def num_registered_towers(self): return len(self._batch_sizes) @property def _batch_size(self): return tf.reduce_sum(self._batch_sizes) class InputOutputMultiTower(object): """Mix-in class for blocks with inputs & outputs and multiple mini-batches.""" def __init__(self, *args, **kwargs): self.__inputs = [] self.__outputs = [] super(InputOutputMultiTower, self).__init__(*args, **kwargs) def _process_data(self, grads_list): """Process data into the format used by the factors. This function takes inputs and grads_lists data and processes it into one of the formats expected by the FisherFactor classes (depending on the value of the global configuration variable TOWER_STRATEGY). The initial format of self._inputs is expected to be a list of Tensors over towers. Similarly grads_lists is expected to be a list over sources of such lists. If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single tensor (represented as a PartitionedTensor object) equal to the concatenation (across towers) of all of the elements of self._inputs. And similarly grads_list is formatted into a tuple (over sources) of such tensors (also represented as PartitionedTensors). If TOWER_STRATEGY is "separate", formatting of inputs and grads_list remains unchanged from the initial format (although possibly converting from lists into tuples). Args: grads_list: grads_list in its initial format (see above). Returns: inputs: self._inputs transformed into the appropriate format (see above). grads_list: grads_list transformed into the appropriate format (see above). 
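# --- Added illustrative sketch (not part of the original library code):
# contrasting the naive "square the sum" estimator used by NaiveDiagonalFB
# above with the per-example "sum of squares" estimator used by the diagonal
# blocks that follow, on made-up numpy data for a dense layer. Both target the
# same Fisher diagonal (the mini-batch-summed gradient has zero mean for the
# exact Fisher), but the naive version has much higher variance.
import numpy as np

rng = np.random.RandomState(0)
N, d_in, d_out = 64, 5, 3
a = rng.randn(N, d_in)        # layer inputs for one mini-batch
ds = rng.randn(N, d_out)      # loss gradients w.r.t. the pre-activations

# "Sum of squares": average the squared per-example gradients a_i * ds_j.
diag_sum_of_squares = (a ** 2).T @ (ds ** 2) / N      # shape [d_in, d_out]

# "Square the sum": square the gradient already summed over the mini-batch,
# then divide by the batch size.
diag_naive = (a.T @ ds) ** 2 / N                      # shape [d_in, d_out]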
Raises: ValueError: if TOWER_STRATEGY is not one of "separate" or "concat". """ inputs = self._inputs # inputs is a list over towers of Tensors # grads_list is a list of list with the first index being sources and the # second being towers. if fisher_factors.TOWER_STRATEGY == "concat": # Merge towers together into a PartitionedTensor. We package it in # a singleton tuple since the factors will expect a list over towers inputs = (utils.PartitionedTensor(inputs),) # Do the same for grads_list but preserve leading sources dimension grads_list = tuple((utils.PartitionedTensor(grads),) for grads in grads_list) elif fisher_factors.TOWER_STRATEGY == "separate": inputs = tuple(inputs) grads_list = tuple(grads_list) else: raise ValueError("Global config variable TOWER_STRATEGY must be one of " "'concat' or 'separate'.") return inputs, grads_list def tensors_to_compute_grads(self): """Tensors to compute derivative of loss with respect to.""" return tuple(self._outputs) def register_additional_tower(self, inputs, outputs): self._inputs.append(inputs) self._outputs.append(outputs) @property def num_registered_towers(self): result = len(self._inputs) assert result == len(self._outputs) return result @property def _inputs(self): return self.__inputs @property def _outputs(self): return self.__outputs class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB): """FisherBlock for fully-connected (dense) layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a fully connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" estimator. Let 'params' be a vector parameterizing a model and 'i' an arbitrary index into it. We are interested in Fisher(params)[i, i]. This is, Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] = E[ v(x, y, params)[i] ^ 2 ] Consider fully connected layer in this model with (unshared) weight matrix 'w'. For an example 'x' that produces layer inputs 'a' and output preactivations 's', v(x, y, w) = vec( a (d loss / d s)^T ) This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding to the layer's parameters 'w'. """ def __init__(self, layer_collection, has_bias=False): """Creates a FullyConnectedDiagonalFB block. Args: layer_collection: The LayerCollection object which owns this block. has_bias: bool. If True, estimates Fisher with respect to a bias parameter as well as the layer's weights. (Default: False) """ self._has_bias = has_bias super(FullyConnectedDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): inputs, grads_list = self._process_data(grads_list) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedDiagonalFactor, (inputs, grads_list, self._has_bias)) self._damping_func = _package_func(lambda: damping, (damping,)) class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB): """FisherBlock for 2-D convolutional layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a convolutional layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" estimator. Let 'params' be a vector parameterizing a model and 'i' an arbitrary index into it. We are interested in Fisher(params)[i, i]. This is, Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] = E[ v(x, y, params)[i] ^ 2 ] Consider a convolutional layer in this model with (unshared) filter matrix 'w'. 
For an example image 'x' that produces layer inputs 'a' and output preactivations 's', v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T ) where 'loc' is a single (x, y) location in an image. This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding to the layer's parameters 'w'. """ def __init__(self, layer_collection, params, strides, padding, data_format=None, dilations=None, patch_mask=None): """Creates a ConvDiagonalFB block. Args: layer_collection: The LayerCollection object which owns this block. params: The parameters (Tensor or tuple of Tensors) of this layer. If kernel alone, a Tensor of shape [kernel_height, kernel_width, in_channels, out_channels]. If kernel and bias, a tuple of 2 elements containing the previous and a Tensor of shape [out_channels]. strides: The stride size in this layer (1-D Tensor of length 4). padding: The padding in this layer (e.g. "SAME"). data_format: str or None. Format of input data. dilations: List of 4 ints or None. Rate for dilation along all dimensions. patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels] or None. If not None this is multiplied against the extracted patches Tensor (broadcasting along the batch dimension) before statistics are computed. (Default: None) Raises: ValueError: if strides is not length-4. ValueError: if dilations is not length-4. ValueError: if channel is not last dimension. """ if len(strides) != 4: raise ValueError("strides must contain 4 numbers.") if dilations is None: dilations = [1, 1, 1, 1] if len(dilations) != 4: raise ValueError("dilations must contain 4 numbers.") if not utils.is_data_format_channel_last(data_format): raise ValueError("data_format must be channels-last.") self._strides = maybe_tuple(strides) self._padding = padding self._data_format = data_format self._dilations = maybe_tuple(dilations) self._has_bias = isinstance(params, (tuple, list)) fltr = params[0] if self._has_bias else params self._filter_shape = tuple(fltr.shape.as_list()) if len(self._filter_shape) != 4: raise ValueError( "Convolution filter must be of shape" " [filter_height, filter_width, in_channels, out_channels].") self._patch_mask = patch_mask super(ConvDiagonalFB, self).__init__(layer_collection) @property def _factor_implementation(self): return fisher_factors.ConvDiagonalFactor def instantiate_factors(self, grads_list, damping): inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. self._num_locations = utils.num_conv_locations(inputs[0].shape.as_list(), list(self._filter_shape), self._strides, self._padding) self._factor = self._layer_collection.make_or_get_factor( self._factor_implementation, (inputs, grads_list, self._filter_shape, self._strides, self._padding, self._data_format, self._dilations, self._has_bias, self._patch_mask)) def damping_func(): return self._num_locations * normalize_damping(damping, self._num_locations) damping_id = (self._num_locations, "mult", "normalize_damping", damping, self._num_locations) self._damping_func = _package_func(damping_func, damping_id) class ScaleAndShiftFullFB(InputOutputMultiTower, FullFB): """A FisherBlock class for scale and shift ops that uses no approximations. This class estimates the same thing that NaiveFullFB would (when applied to the scale and shift params), but with a lower variance estimator. In particular it uses a "sum the squares estimator", and thus the variance will shrink as 1/batch_size. 
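# --- Added illustrative sketch (not part of the original library code): the
# ConvDiagonalFB estimate described above, on made-up numpy data. The
# per-example gradient of the filter is the sum over spatial locations of
# outer(a_loc, ds_loc); the diagonal estimate averages its element-wise square
# over the mini-batch.
import numpy as np

rng = np.random.RandomState(0)
N, L, k, c_out = 8, 6, 4, 3        # batch, #locations, kh*kw*c_in, out channels
patches = rng.randn(N, L, k)       # extracted input patches a_{loc}
ds = rng.randn(N, L, c_out)        # per-location pre-activation gradients

per_example_grad = np.einsum('nlk,nlc->nkc', patches, ds)   # [N, k, c_out]
diag_estimate = np.mean(per_example_grad ** 2, axis=0)      # [k, c_out]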
""" def __init__(self, layer_collection, broadcast_dims_scale, broadcast_dims_shift=None, has_shift=True): """Creates a ScaleAndShiftFullFB block. Args: layer_collection: The LayerCollection object which owns this block. broadcast_dims_scale: A list of dimension indices that are broadcast along during the scale operation. Does not include batch dimension. broadcast_dims_shift: A list of dimension indices that are broadcast along during the shift operation. Does not include batch dimension. has_shift: bool. If True, estimates Fisher with respect to a shift parameter as well the scale parameter (which is always included). """ self._broadcast_dims_scale = broadcast_dims_scale self._broadcast_dims_shift = broadcast_dims_shift self._has_shift = has_shift super(ScaleAndShiftFullFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): inputs, grads_list = self._process_data(grads_list) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.ScaleAndShiftFullFactor, (inputs, grads_list, self._broadcast_dims_scale, self._broadcast_dims_shift, self._has_shift)) self._damping_func = _package_func(lambda: damping, (damping,)) class ScaleAndShiftDiagonalFB(InputOutputMultiTower, DiagonalFB): """A FisherBlock class for scale and shift ops that uses a diagonal approx. This class estimates the same thing that NaiveDiagonalFB would (when applied to the scale and shift params), but with a lower variance estimator. In particular it uses a "sum the squares estimator", and thus the variance will shrink as 1/batch_size. """ def __init__(self, layer_collection, broadcast_dims_scale, broadcast_dims_shift=None, has_shift=True): """Creates a ScaleAndShiftDiagonalFB block. Args: layer_collection: The LayerCollection object which owns this block. broadcast_dims_scale: A list of dimension indices that are broadcast along during the scale operation. Does not include batch dimension. broadcast_dims_shift: A list of dimension indices that are broadcast along during the shift operation. Does not include batch dimension. has_shift: bool. If True, estimates Fisher with respect to a shift parameter as well the scale parameter (which is always included). """ self._broadcast_dims_scale = broadcast_dims_scale self._broadcast_dims_shift = broadcast_dims_shift self._has_shift = has_shift super(ScaleAndShiftDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): inputs, grads_list = self._process_data(grads_list) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.ScaleAndShiftDiagonalFactor, (inputs, grads_list, self._broadcast_dims_scale, self._broadcast_dims_shift, self._has_shift)) self._damping_func = _package_func(lambda: damping, (damping,)) class KroneckerProductFB(FisherBlock): """A base class for blocks with separate input and output Kronecker factors. The Fisher block is approximated as a Kronecker product of the input and output factors. 
""" def _setup_damping(self, damping, normalization=None): """Makes functions that compute the damping values for both factors.""" def compute_damping(): if normalization is not None: maybe_normalized_damping = normalize_damping(damping, normalization) else: maybe_normalized_damping = damping return compute_pi_adjusted_damping( self._input_factor.get_cov_as_linear_operator(), self._output_factor.get_cov_as_linear_operator(), maybe_normalized_damping**0.5) if normalization is not None: damping_id = ("compute_pi_adjusted_damping", "cov", self._input_factor.name, "cov", self._output_factor.name, "normalize_damping", damping, normalization, "power", 0.5) else: damping_id = ("compute_pi_adjusted_damping", "cov", self._input_factor.name, "cov", self._output_factor.name, damping, "power", 0.5) self._input_damping_func = _package_func(lambda: compute_damping()[0], damping_id + ("ref", 0)) self._output_damping_func = _package_func(lambda: compute_damping()[1], damping_id + ("ref", 1)) # Also store the damping op for access to the effective damping later on, # such as when writing to summary. if normalization is not None: self._damping = normalize_damping(damping, normalization) else: self._damping = damping def register_matpower(self, exp): self._input_factor.register_matpower(exp, self._input_damping_func) self._output_factor.register_matpower(exp, self._output_damping_func) def register_cholesky(self): self._input_factor.register_cholesky(self._input_damping_func) self._output_factor.register_cholesky(self._output_damping_func) def register_cholesky_inverse(self): self._input_factor.register_cholesky_inverse(self._input_damping_func) self._output_factor.register_cholesky_inverse(self._output_damping_func) @property def damping(self): """A copy of the damping op. This is not used (and should never be used) in KFAC computations. A valid usage of this property could be to write damping values to the summary. Returns: 0-D Tensor. """ return self._damping @property def input_factor(self): return self._input_factor @property def output_factor(self): return self._output_factor @property def _renorm_coeff(self): """Kronecker factor multiplier coefficient. If this FisherBlock is represented as 'FB = c * kron(left, right)', then this is 'c'. Returns: 0-D Tensor. 
""" return 1.0 def _multiply_factored_matrix(self, left_factor, right_factor, vector, extra_scale=1.0, transpose_left=False, transpose_right=False): """Multiplies a factored matrix.""" reshaped_vector = utils.layer_params_to_mat2d(vector) reshaped_out = right_factor.matmul_right(reshaped_vector, adjoint=transpose_right) reshaped_out = left_factor.matmul(reshaped_out, adjoint=transpose_left) if extra_scale != 1.0: reshaped_out = tf.scalar_mul(extra_scale, reshaped_out) return utils.mat2d_to_layer_params(vector, reshaped_out) def multiply_matpower(self, vector, exp): left_factor = self._input_factor.get_matpower( exp, self._input_damping_func) right_factor = self._output_factor.get_matpower( exp, self._output_damping_func) extra_scale = float(self._renorm_coeff)**exp return self._multiply_factored_matrix(left_factor, right_factor, vector, extra_scale=extra_scale) def multiply_cholesky(self, vector, transpose=False): left_factor = self._input_factor.get_cholesky(self._input_damping_func) right_factor = self._output_factor.get_cholesky(self._output_damping_func) extra_scale = float(self._renorm_coeff)**0.5 return self._multiply_factored_matrix(left_factor, right_factor, vector, extra_scale=extra_scale, transpose_left=transpose, transpose_right=not transpose) def multiply_cholesky_inverse(self, vector, transpose=False): left_factor = self._input_factor.get_cholesky_inverse( self._input_damping_func) right_factor = self._output_factor.get_cholesky_inverse( self._output_damping_func) extra_scale = float(self._renorm_coeff)**-0.5 return self._multiply_factored_matrix(left_factor, right_factor, vector, extra_scale=extra_scale, transpose_left=transpose, transpose_right=not transpose) def full_fisher_block(self): """Explicitly constructs the full Fisher block. Used for testing purposes. (In general, the result may be very large.) Returns: The full Fisher block. """ left_factor = self._input_factor.get_cov_as_linear_operator().to_dense() right_factor = self._output_factor.get_cov_as_linear_operator().to_dense() return self._renorm_coeff * utils.kronecker_product(left_factor, right_factor) class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB): """K-FAC FisherBlock for fully-connected (dense) layers. This uses the Kronecker-factorized approximation from the original K-FAC paper (https://arxiv.org/abs/1503.05671) """ def __init__(self, layer_collection, has_bias=False, diagonal_approx_for_input=False, diagonal_approx_for_output=False): """Creates a FullyConnectedKFACBasicFB block. Args: layer_collection: The LayerCollection object which owns this block. has_bias: bool. If True, estimates Fisher with respect to a bias parameter as well as the layer's weights. (Default: False) diagonal_approx_for_input: Whether to use diagonal approximation for the input Kronecker factor. (Default: False) diagonal_approx_for_output: Whether to use diagonal approximation for the output Kronecker factor. (Default: False) """ self._has_bias = has_bias self._diagonal_approx_for_input = diagonal_approx_for_input self._diagonal_approx_for_output = diagonal_approx_for_output super(FullyConnectedKFACBasicFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): """Instantiate Kronecker Factors for this FisherBlock. Args: grads_list: List of list of Tensors. grads_list[i][j] is the gradient of the loss with respect to 'outputs' from source 'i' and tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. damping: 0-D Tensor or float. 
'damping' * identity is approximately added to this FisherBlock's Fisher approximation. """ inputs, grads_list = self._process_data(grads_list) if self._diagonal_approx_for_input: self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.DiagonalKroneckerFactor, ((inputs,), self._has_bias)) else: self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedKroneckerFactor, ((inputs,), self._has_bias)) if self._diagonal_approx_for_output: self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.DiagonalKroneckerFactor, (grads_list,)) else: self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) self._setup_damping(damping) class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): """FisherBlock for convolutional layers using the basic KFC approx. Estimates the Fisher Information matrix's blog for a convolutional layer. Consider a convolutional layer in this model with (unshared) filter matrix 'w'. For a minibatch that produces inputs 'a' and output preactivations 's', this FisherBlock estimates, F(w) = #locations * kronecker(E[flat(a) flat(a)^T], E[flat(ds) flat(ds)^T]) where ds = (d / ds) log p(y | x, w) #locations = number of (x, y) locations where 'w' is applied. where the expectation is taken over all examples and locations and flat() concatenates an array's leading dimensions. See equation 23 in https://arxiv.org/abs/1602.01407 for details. """ def __init__(self, layer_collection, params, padding, strides=None, dilation_rate=None, data_format=None, extract_patches_fn=None, sub_sample_inputs=None, sub_sample_patches=None, use_sua_approx_for_input_factor=False, patch_mask=None): """Creates a ConvKFCBasicFB block. Args: layer_collection: The LayerCollection object which owns this block. params: The parameters (Tensor or tuple of Tensors) of this layer. If kernel alone, a Tensor of shape [..spatial_filter_shape.., in_channels, out_channels]. If kernel and bias, a tuple of 2 elements containing the previous and a Tensor of shape [out_channels]. padding: str. Padding method. strides: List of ints or None. Contains [..spatial_filter_strides..] if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_filter_strides, 1]. dilation_rate: List of ints or None. Rate for dilation along each spatial dimension if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. data_format: str or None. Format of input data. extract_patches_fn: str or None. Name of function that extracts image patches. One of "extract_convolution_patches", "extract_image_patches", "extract_pointwise_conv2d_patches". sub_sample_inputs: `bool`. If True, then subsample the inputs from which the image patches are extracted. (Default: None) sub_sample_patches: `bool`, If `True` then subsample the extracted patches. (Default: None) use_sua_approx_for_input_factor: `bool`, If `True` then use `ConvInputSUAKroneckerFactor` for input factor. Otherwise use `ConvInputKroneckerFactor`. (Default: None) patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels] or None. If not None this is multiplied against the extracted patches Tensor (broadcasting along the batch dimension) before statistics are computed in the input factor. 
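# --- Added illustrative sketch (not part of the original library code): the
# KFC factors from the formula in the ConvKFCBasicFB class docstring above,
# estimated from a mini-batch of made-up numpy data. flat() concatenates the
# batch and location dimensions, and the block is #locations times the
# Kronecker product of the two factors.
import numpy as np

rng = np.random.RandomState(0)
N, L, k, c_out = 8, 6, 4, 3        # batch, #conv locations, kh*kw*c_in, out channels
patches = rng.randn(N, L, k)       # extracted image patches
ds = rng.randn(N, L, c_out)        # per-location pre-activation gradients

flat_a = patches.reshape(N * L, k)
flat_ds = ds.reshape(N * L, c_out)
A = flat_a.T @ flat_a / (N * L)    # E[flat(a) flat(a)^T]
G = flat_ds.T @ flat_ds / (N * L)  # E[flat(ds) flat(ds)^T]
F_approx = L * np.kron(A, G)       # #locations * kron(A, G)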
(Default: None) """ self._padding = padding self._strides = maybe_tuple(strides) self._dilation_rate = maybe_tuple(dilation_rate) self._data_format = data_format self._extract_patches_fn = extract_patches_fn self._has_bias = isinstance(params, (tuple, list)) self._use_sua_approx_for_input_factor = use_sua_approx_for_input_factor fltr = params[0] if self._has_bias else params self._filter_shape = tuple(fltr.shape.as_list()) self._sub_sample_inputs = sub_sample_inputs self._sub_sample_patches = sub_sample_patches self._patch_mask = patch_mask super(ConvKFCBasicFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. self._num_locations = utils.num_conv_locations(inputs[0].shape.as_list(), list(self._filter_shape), self._strides, self._padding) if self._use_sua_approx_for_input_factor: self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputSUAKroneckerFactor, (inputs, self._filter_shape, self._has_bias)) else: self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputKroneckerFactor, (inputs, self._filter_shape, self._padding, self._strides, self._dilation_rate, self._data_format, self._extract_patches_fn, self._has_bias, self._sub_sample_inputs, self._sub_sample_patches, self._patch_mask)) self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) self._setup_damping(damping, normalization=self._num_locations) @property def _renorm_coeff(self): return self._num_locations class DepthwiseConvDiagonalFB(ConvDiagonalFB): """FisherBlock for depthwise_conv2d(). Equivalent to ConvDiagonalFB applied to each input channel in isolation. """ def __init__(self, layer_collection, params, strides, padding, rate=None, data_format=None): """Creates a DepthwiseConvKFCBasicFB block. Args: layer_collection: The LayerCollection object which owns this block. params: Tensor of shape [filter_height, filter_width, in_channels, channel_multiplier]. strides: List of 4 ints. Strides along all dimensions. padding: str. Padding method. rate: List of 4 ints or None. Rate for dilation along all dimensions. data_format: str or None. Format of input data. Raises: NotImplementedError: If parameters contains bias. ValueError: If filter is not 4-D. ValueError: If strides is not length-4. ValueError: If rates is not length-2. ValueError: If channels are not last dimension. """ if isinstance(params, (tuple, list)): raise NotImplementedError("Bias not yet supported.") if params.shape.ndims != 4: raise ValueError("Filter must be 4-D.") if len(strides) != 4: raise ValueError("strides must account for 4 dimensions.") if rate is not None: if len(rate) != 2: raise ValueError("rate must only account for spatial dimensions.") rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. if not utils.is_data_format_channel_last(data_format): raise ValueError("data_format must be channels-last.") super(DepthwiseConvDiagonalFB, self).__init__( layer_collection=layer_collection, params=params, strides=strides, padding=padding, dilations=rate, data_format=data_format) # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). 
filter_height, filter_width, in_channels, channel_multiplier = ( params.shape.as_list()) self._filter_shape = (filter_height, filter_width, in_channels, in_channels * channel_multiplier) def _multiply_matrix(self, matrix, vector): conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) conv2d_result = super( DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector) return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): """FisherBlock for depthwise_conv2d(). Equivalent to ConvKFCBasicFB applied to each input channel in isolation. """ def __init__(self, layer_collection, params, strides, padding, rate=None, data_format=None): """Creates a DepthwiseConvKFCBasicFB block. Args: layer_collection: The LayerCollection object which owns this block. params: Tensor of shape [filter_height, filter_width, in_channels, channel_multiplier]. strides: List of 4 ints. Strides along all dimensions. padding: str. Padding method. rate: List of 4 ints or None. Rate for dilation along all dimensions. data_format: str or None. Format of input data. Raises: NotImplementedError: If parameters contains bias. ValueError: If filter is not 4-D. ValueError: If strides is not length-4. ValueError: If rates is not length-2. ValueError: If channels are not last dimension. """ if isinstance(params, (tuple, list)): raise NotImplementedError("Bias not yet supported.") if params.shape.ndims != 4: raise ValueError("Filter must be 4-D.") if len(strides) != 4: raise ValueError("strides must account for 4 dimensions.") if rate is not None: if len(rate) != 2: raise ValueError("rate must only account for spatial dimensions.") rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. if not utils.is_data_format_channel_last(data_format): raise ValueError("data_format must be channels-last.") super(DepthwiseConvKFCBasicFB, self).__init__( layer_collection=layer_collection, params=params, padding=padding, strides=strides, dilation_rate=rate, data_format=data_format, extract_patches_fn="extract_image_patches") # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). filter_height, filter_width, in_channels, channel_multiplier = ( params.shape.as_list()) self._filter_shape = (filter_height, filter_width, in_channels, in_channels * channel_multiplier) def _multiply_factored_matrix(self, left_factor, right_factor, vector, extra_scale=1.0, transpose_left=False, transpose_right=False): conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) conv2d_result = super( DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix( left_factor, right_factor, conv2d_vector, extra_scale=extra_scale, transpose_left=transpose_left, transpose_right=transpose_right) return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin """Converts a convolution filter for use with conv2d. Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's compatible with tf.nn.conv2d(). Args: filter: Tensor of shape [height, width, in_channels, channel_multiplier]. name: None or str. Name of Op. Returns: Tensor of shape [height, width, in_channels, out_channels]. """ with tf.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter", [filter]): filter = tf.convert_to_tensor(filter) filter_height, filter_width, in_channels, channel_multiplier = ( filter.shape.as_list()) results = [] for i in range(in_channels): # Slice out one in_channel's filter. 
Insert zeros around it to force it # to affect that channel and that channel alone. elements = [] if i > 0: elements.append( tf.zeros([filter_height, filter_width, i, channel_multiplier])) elements.append(filter[:, :, i:(i + 1), :]) if i + 1 < in_channels: elements.append( tf.zeros([ filter_height, filter_width, in_channels - (i + 1), channel_multiplier ])) # Concat along in_channel. results.append(tf.concat(elements, axis=-2, name="in_channel_%d" % i)) # Concat along out_channel. return tf.concat(results, axis=-1, name="out_channel") def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin """Converts a convolution filter for use with depthwise_conv2d. Transforms a filter for use with tf.nn.conv2d() to one that's compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along the diagonal. Args: filter: Tensor of shape [height, width, in_channels, out_channels]. name: None or str. Name of Op. Returns: Tensor of shape, [height, width, in_channels, channel_multiplier] Raises: ValueError: if out_channels is not evenly divisible by in_channels. """ with tf.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter", [filter]): filter = tf.convert_to_tensor(filter) filter_height, filter_width, in_channels, out_channels = ( filter.shape.as_list()) if out_channels % in_channels != 0: raise ValueError("out_channels must be evenly divisible by in_channels.") channel_multiplier = out_channels // in_channels results = [] filter = tf.reshape(filter, [ filter_height, filter_width, in_channels, in_channels, channel_multiplier ]) for i in range(in_channels): # Slice out output corresponding to the correct filter. filter_slice = tf.reshape( filter[:, :, i, i, :], [filter_height, filter_width, 1, channel_multiplier]) results.append(filter_slice) # Concat along out_channel. return tf.concat(results, axis=-2, name="in_channels") def maybe_tuple(obj): if not isinstance(obj, list): return obj return tuple(obj) class InputOutputMultiTowerMultiUse(InputOutputMultiTower): """Adds methods for multi-use/time-step case to InputOutputMultiTower.""" def __init__(self, num_uses=None, *args, **kwargs): self._num_uses = num_uses super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs) def _process_data(self, grads_list): """Process temporal/multi-use data into the format used by the factors. This function takes inputs and grads_lists data and processes it into one of the formats expected by the FisherFactor classes (depending on the value of the global configuration variable TOWER_STRATEGY). It accepts the data in one of two initial formats. The first possible format is where self._inputs is a list of list of Tensors. The first index is tower, the second is use/time-step. grads_list, meanwhile, is a list over sources of such lists of lists. The second possible data format is where self._inputs is a list of Tensors (over towers), where each tensor either has shape [num_uses, batch_size, ...] or each tensor has shape [num_uses*batch_size, ...] (which is formed by reshaping tensors of the first format). And similarly grads_list is a list over sources of such lists of Tensors. There are two possible formats which inputs and grads_list are transformed into. If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single tensor (represented as a PartitionedTensor object) with all of the data from the towers, as well as the uses/time-steps, concatenated together. The format of this tensor is the same as the second input data format above. 
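# --- Added illustrative sketch (not part of the original library code): the
# two input formats accepted here, on made-up numpy data for a layer used
# num_uses times (e.g. an unrolled RNN step), and what the "concat" strategy
# produces. np.concatenate stands in for the lazy PartitionedTensor wrapper,
# and the exact concatenation order (the library goes use-major) does not
# affect the resulting covariance estimates.
import numpy as np

num_towers, num_uses, batch, d = 2, 3, 4, 5
rng = np.random.RandomState(0)

# Format 1: a list over towers of lists over uses/time-steps.
inputs_fmt1 = [[rng.randn(batch, d) for _ in range(num_uses)]
               for _ in range(num_towers)]

# Format 2: a list over towers, with uses folded into the batch dimension.
inputs_fmt2 = [np.concatenate(tower, axis=0) for tower in inputs_fmt1]
assert inputs_fmt2[0].shape == (num_uses * batch, d)

# "concat": uses and towers are all merged into one tensor, packaged in a
# singleton tuple so downstream code still sees "a list over towers".
inputs_concat = (np.concatenate([t for tower in inputs_fmt1 for t in tower],
                                axis=0),)
assert inputs_concat[0].shape == (num_towers * num_uses * batch, d)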
Similarly, grads_list is a tuple over sources of such lists of tensors. If TOWER_STRATEGY is "separate" the inputs are formatted into lists of tensors over towers. Each of these tensors has a similar format to the tensor produced by the "concat" option, except that each contains only the data from a single tower. grads_list is similarly formatted into a tuple over sources of such tuples. Args: grads_list: grads_list in its initial format (see above). Returns: inputs: self._inputs transformed into the appropriate format (see above). grads_list: grads_list transformed into the appropriate format (see above). Raises: ValueError: If TOWER_STRATEGY is not one of "separate" or "concat". ValueError: If the given/initial format of self._inputs and grads_list isn't recognized, or doesn't agree with self._num_uses. """ inputs = self._inputs # The first data format. if isinstance(inputs[0], (list, tuple)): num_uses = len(inputs[0]) if self._num_uses is not None and self._num_uses != num_uses: raise ValueError("num_uses argument doesn't match length of inputs.") else: self._num_uses = num_uses # Check that all mini-batches/towers have the same number of uses if not all(len(input_) == num_uses for input_ in inputs): raise ValueError("Length of inputs argument is inconsistent across " "towers.") if fisher_factors.TOWER_STRATEGY == "concat": # Reverse the tower and use/time-step indices, so that use is now first, # and towers is second inputs = tuple(zip(*inputs)) # Flatten the two dimensions inputs = nest.flatten(inputs) # Merge everything together into a PartitionedTensor. We package it in # a singleton tuple since the factors will expect a list over towers inputs = (utils.PartitionedTensor(inputs),) elif fisher_factors.TOWER_STRATEGY == "separate": # Merge together the uses/time-step dimension into PartitionedTensors, # but keep the leading dimension (towers) intact for the factors to # process individually. inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs) else: raise ValueError("Global config variable TOWER_STRATEGY must be one of " "'concat' or 'separate'.") # The second data format else: inputs = tuple(inputs) # Now we perform the analogous processing for grads_list # The first data format. if isinstance(grads_list[0][0], (list, tuple)): num_uses = len(grads_list[0][0]) if self._num_uses is not None and self._num_uses != num_uses: raise ValueError("num_uses argument doesn't match length of outputs, " "or length of outputs is inconsistent with length of " "inputs.") else: self._num_uses = num_uses if not all(len(grad) == num_uses for grads in grads_list for grad in grads): raise ValueError("Length of outputs argument is inconsistent across " "towers.") if fisher_factors.TOWER_STRATEGY == "concat": # Reverse the tower and use/time-step indices, so that use is now first, # and towers is second grads_list = tuple(tuple(zip(*grads)) for grads in grads_list) # Flatten the two dimensions, leaving the leading dimension (source) # intact grads_list = tuple(nest.flatten(grads) for grads in grads_list) # Merge inner dimensions together into PartitionedTensors. We package # them in a singleton tuple since the factors will expect a list over # towers grads_list = tuple((utils.PartitionedTensor(grads),) for grads in grads_list) elif fisher_factors.TOWER_STRATEGY == "separate": # Merge together the uses/time-step dimension into PartitionedTensors, # but keep the leading dimension (towers) intact for the factors to # process individually. 
grads_list = tuple(tuple(utils.PartitionedTensor(grad) for grad in grads) for grads in grads_list) else: raise ValueError("Global config variable TOWER_STRATEGY must be one of " "'concat' or 'separate'.") # The second data format. else: grads_list = tuple(tuple(grads) for grads in grads_list) if self._num_uses is None: raise ValueError("You must supply a value for the num_uses argument if " "the number of uses cannot be inferred from inputs or " "outputs arguments (e.g. if they are both given in the " "single Tensor format, instead of as lists of Tensors.") return inputs, grads_list class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse, KroneckerProductFB): """FisherBlock for fully-connected layers that share parameters. This class implements the "independence across time" approximation from the following paper: https://openreview.net/pdf?id=HyMTkQZAb """ def __init__(self, layer_collection, has_bias=False, num_uses=None, diagonal_approx_for_input=False, diagonal_approx_for_output=False): """Creates a FullyConnectedMultiIndepFB block. Args: layer_collection: LayerCollection instance. has_bias: bool. If True, estimates Fisher with respect to a bias parameter as well as the layer's weights. (Default: False) num_uses: int or None. Number of uses of the layer in the model's graph. Only required if the data is formatted with uses/time folded into the batch dimension (instead of uses/time being a list dimension). (Default: None) diagonal_approx_for_input: Whether to use diagonal approximation for the input Kronecker factor. (Default: False) diagonal_approx_for_output: Whether to use diagonal approximation for the output Kronecker factor. (Default: False) """ self._has_bias = has_bias self._diagonal_approx_for_input = diagonal_approx_for_input self._diagonal_approx_for_output = diagonal_approx_for_output super(FullyConnectedMultiIndepFB, self).__init__( layer_collection=layer_collection, num_uses=num_uses) def instantiate_factors(self, grads_list, damping): inputs, grads_list = self._process_data(grads_list) if self._diagonal_approx_for_input: self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.DiagonalMultiKF, ((inputs,), self._num_uses, self._has_bias)) else: self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedMultiKF, ((inputs,), self._num_uses, self._has_bias)) if self._diagonal_approx_for_output: self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.DiagonalMultiKF, (grads_list, self._num_uses)) else: self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) self._setup_damping(damping, normalization=self._num_uses) @property def _renorm_coeff(self): return float(self._num_uses) class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, KroneckerProductFB): """FisherBlock for 2D convolutional layers using the basic KFC approx. Similar to ConvKFCBasicFB except that this version supports multiple uses/time-steps via a standard independence approximation. Similar to the "independence across time" used in FullyConnectedMultiIndepFB but generalized in the obvious way to conv layers. """ def __init__(self, layer_collection, params, padding, strides=None, dilation_rate=None, data_format=None, extract_patches_fn=None, num_uses=None): """Creates a ConvKFCBasicMultiIndepFB block. Args: layer_collection: The LayerCollection object which owns this block. params: The parameters (Tensor or tuple of Tensors) of this layer. 
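# --- Added illustrative sketch (not part of the original library code): the
# "independence across time" approximation used by the multi-use blocks here,
# shown for a dense layer on made-up numpy data. The factors are estimated
# over examples and uses jointly, and the block is scaled by num_uses
# (cf. _renorm_coeff above).
import numpy as np

rng = np.random.RandomState(0)
N, T, d_in, d_out = 16, 5, 4, 3     # batch, num_uses, layer dims
a = rng.randn(N, T, d_in)           # inputs at each use of the shared layer
ds = rng.randn(N, T, d_out)         # pre-activation gradients at each use

a_flat = a.reshape(N * T, d_in)     # fold uses into the batch dimension
ds_flat = ds.reshape(N * T, d_out)
A = a_flat.T @ a_flat / (N * T)
G = ds_flat.T @ ds_flat / (N * T)
F_approx = T * np.kron(A, G)        # num_uses * kron(A, G)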
If kernel alone, a Tensor of shape [..spatial_filter_shape.., in_channels, out_channels]. If kernel and bias, a tuple of 2 elements containing the previous and a Tensor of shape [out_channels]. padding: str. Padding method. strides: List of ints or None. Contains [..spatial_filter_strides..] if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_filter_strides, 1]. dilation_rate: List of ints or None. Rate for dilation along each spatial dimension if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. data_format: str or None. Format of input data. extract_patches_fn: str or None. Name of function that extracts image patches. One of "extract_convolution_patches", "extract_image_patches", "extract_pointwise_conv2d_patches". num_uses: int or None. Number of uses of the layer in the model's graph. Only required if the data is formatted with uses/time folded into the batch dimension (instead of uses/time being a list dimension). (Default: None) """ self._padding = padding self._strides = maybe_tuple(strides) self._dilation_rate = maybe_tuple(dilation_rate) self._data_format = data_format self._extract_patches_fn = extract_patches_fn self._has_bias = isinstance(params, (tuple, list)) fltr = params[0] if self._has_bias else params self._filter_shape = tuple(fltr.shape.as_list()) super(ConvKFCBasicMultiIndepFB, self).__init__( layer_collection=layer_collection, num_uses=num_uses) def instantiate_factors(self, grads_list, damping): inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. self._num_locations = utils.num_conv_locations(inputs[0].shape.as_list(), list(self._filter_shape), self._strides, self._padding) self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputMultiKF, (inputs, self._filter_shape, self._padding, self._num_uses, self._strides, self._dilation_rate, self._data_format, self._extract_patches_fn, self._has_bias, self._num_uses)) self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvOutputMultiKF, (grads_list, self._num_uses, self._data_format)) self._setup_damping(damping, normalization=(self._num_locations * self._num_uses)) @property def _renorm_coeff(self): return self._num_locations * self._num_uses class SeriesFBApproximation(object): """See FullyConnectedSeriesFB.__init__ for description and usage.""" option1 = 1 option2 = 2 class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, KroneckerProductFB): """FisherBlock for fully-connected layers that share parameters across time. This class implements the "Option 1" and "Option 2" approximation from the following paper: https://openreview.net/pdf?id=HyMTkQZAb See the end of the appendix of the paper for a pseudo-code of the algorithm being implemented by multiply_matpower here. Note that we are using pre-computed versions of certain matrix-matrix products to speed things up. This is explicitly explained wherever it is done. """ def __init__(self, layer_collection, has_bias=False, num_uses=None, option=SeriesFBApproximation.option2): """Constructs a new `FullyConnectedSeriesFB`. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. has_bias: bool. If True, estimates Fisher with respect to a bias parameter as well as the layer's weights. num_uses: int or None. Number of time-steps over which the layer is used. 
Only required if the data is formatted with time folded into the batch dimension (instead of time being a list dimension). (Default: None) option: A `SeriesFBApproximation` specifying the simplifying assumption to be used in this block. `option1` approximates the cross-covariance over time as a symmetric matrix, while `option2` makes the assumption that training sequences are infinitely long. See section 3.5 of the paper for more details. """ self._has_bias = has_bias self._option = option super(FullyConnectedSeriesFB, self).__init__( layer_collection=layer_collection, num_uses=num_uses) @property def _num_timesteps(self): return self._num_uses @property def _renorm_coeff(self): # This should no longer be used since the multiply_X functions from the base # class have been overridden assert False def instantiate_factors(self, grads_list, damping): inputs, grads_list = self._process_data(grads_list) self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedMultiKF, ((inputs,), self._num_uses, self._has_bias)) self._input_factor.register_cov_dt1() self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) self._output_factor.register_cov_dt1() self._setup_damping(damping, normalization=self._num_uses) def register_matpower(self, exp): if exp != -1: raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" "multiplications.") if self._option == SeriesFBApproximation.option1: self._input_factor.register_option1quants(self._input_damping_func) self._output_factor.register_option1quants(self._output_damping_func) elif self._option == SeriesFBApproximation.option2: self._input_factor.register_option2quants(self._input_damping_func) self._output_factor.register_option2quants(self._output_damping_func) else: raise ValueError( "Unrecognized FullyConnectedSeriesFB approximation: {}".format( self._option)) def multiply_matpower(self, vector, exp): if exp != -1: raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" "multiplications.") # pylint: disable=invalid-name Z = utils.layer_params_to_mat2d(vector) # Derivations were done for "batch_dim==1" case so we need to convert to # that orientation: Z = tf.transpose(Z) if self._option == SeriesFBApproximation.option1: # Note that L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G. L_A, psi_A = self._input_factor.get_option1quants( self._input_damping_func) L_G, psi_G = self._output_factor.get_option1quants( self._output_damping_func) def gamma(x): # We are assuming that each case has the same number of time-steps. # If this stops being the case one shouldn't simply replace this T # with its average value. Instead, one needs to go back to the # definition of the gamma function from the paper. T = self._num_timesteps return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) # Y = \gamma( psi_G*psi_A^T ) (computed element-wise) # Even though Y is Z-independent we are recomputing it from the psi's # each since Y depends on both A and G quantities, and it is relatively # cheap to compute. 
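      # Added note (not in the original source): gamma here is
      #   gamma(x) = (1 - x)**2 / (T*(1 - x**2) - 2*x*(1 - x**T)),
      # evaluated element-wise on the outer product of the psi vectors. For
      # x = 0 (no temporal correlation) it reduces to 1/T, and for large T it
      # behaves like (1 - x) / (T * (1 + x)).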
Y = gamma(tf.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) # Z = L_G^T * Z * L_A # This is equivalent to the following computation from the original # pseudo-code: # Z = G0^{-1/2} * Z * A0^{-1/2} # Z = U_G^T * Z * U_A Z = tf.matmul(L_G, tf.matmul(Z, L_A), transpose_a=True) # Z = Z .* Y Z *= Y # Z = L_G * Z * L_A^T # This is equivalent to the following computation from the original # pseudo-code: # Z = U_G * Z * U_A^T # Z = G0^{-1/2} * Z * A0^{-1/2} Z = tf.matmul(L_G, tf.matmul(Z, L_A, transpose_b=True)) elif self._option == SeriesFBApproximation.option2: # Note that P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}, # and K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G. P_A, K_A, mu_A = self._input_factor.get_option2quants( self._input_damping_func) P_G, K_G, mu_G = self._output_factor.get_option2quants( self._output_damping_func) # Our approach differs superficially from the pseudo-code in the paper # in order to reduce the total number of matrix-matrix multiplies. # In particular, the first three computations in the pseudo code are # Z = G0^{-1/2} * Z * A0^{-1/2} # Z = Z - hPsi_G^T * Z * hPsi_A # Z = E_G^T * Z * E_A # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}, so that # C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2} # the entire computation can be written as # Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2} # - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A # = E_G^T * (G0^{-1/2} * Z * A0^{-1/2} # - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A # = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A # - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A # = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A # This final expression is computed by the following two lines: # Z = Z - P_G * Z * P_A^T Z -= tf.matmul(P_G, tf.matmul(Z, P_A, transpose_b=True)) # Z = K_G^T * Z * K_A Z = tf.matmul(K_G, tf.matmul(Z, K_A), transpose_a=True) # Z = Z ./ (1*1^T - mu_G*mu_A^T) # Be careful with the outer product. We don't want to accidentally # make it an inner-product instead. tmp = 1.0 - tf.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A # Prevent some numerical issues by setting any 0.0 eigs to 1.0 tmp += 1.0 * tf.cast(tf.equal(tmp, 0.0), dtype=tmp.dtype) Z /= tmp # We now perform the transpose/reverse version of the operations # derived above, whose derivation from the original pseudo-code is # analgous. # Z = K_G * Z * K_A^T Z = tf.matmul(K_G, tf.matmul(Z, K_A, transpose_b=True)) # Z = Z - P_G^T * Z * P_A Z -= tf.matmul(P_G, tf.matmul(Z, P_A), transpose_a=True) # Z = normalize (1/E[T]) * Z # Note that this normalization is done because we compute the statistics # by averaging, not summing, over time. (And the gradient is presumably # summed over time, not averaged, and thus their scales are different.) Z /= tf.cast(self._num_timesteps, Z.dtype) # Convert back to the "batch_dim==0" orientation. Z = tf.transpose(Z) return utils.mat2d_to_layer_params(vector, Z) # pylint: enable=invalid-name def multiply_cholesky(self, vector): raise NotImplementedError("FullyConnectedSeriesFB does not support " "Cholesky computations.") def multiply_cholesky_inverse(self, vector): raise NotImplementedError("FullyConnectedSeriesFB does not support " "Cholesky computations.") ================================================ FILE: kfac/python/ops/fisher_factors.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """FisherFactor definitions.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import abc import contextlib import math # Dependency imports import numpy as np import six import tensorflow.compat.v1 as tf from collections import OrderedDict from tensorflow.python.util import nest from kfac.python.ops import linear_operator as lo from kfac.python.ops import utils # Whether to initialize covariance estimators at a zero matrix (or the identity # matrix). INIT_COVARIANCES_AT_ZERO = True # Whether to zero-debias the moving averages. ZERO_DEBIAS = True # Whether to initialize inverse (and other such matrices computed from the cov # matrices) to the zero matrix (or the identity matrix). Initializing to # zero is a safeguard against anything using the inverse before their first # proper update, and so is preferred. INIT_INVERSES_AT_ZERO = True # When the number of inverses requested from a FisherFactor is >= this value, # the inverses are computed using an eigenvalue decomposition. EIGENVALUE_DECOMPOSITION_THRESHOLD = 4 # Numerical eigenvalues computed from covariance matrix estimates are clipped to # be at least as large as this value before they are used to compute inverses or # matrix powers. Must be nonnegative. EIGENVALUE_CLIPPING_THRESHOLD = 0.0 # When approximating conv layer input factor using spatially uncorrelated # activations (`ConvInputSUAKroneckerfactor`) if this is True then assumes the # activations to have zero mean. ASSUME_ZERO_MEAN_ACTIVATIONS = False # When approximating conv layer input factor using spatially uncorrelated # activations (`ConvInputSUAKroneckerfactor`) if this is True then do # mean subtraction from covariance matrix. Note this flag is only checked in the # case where ASSUME_ZERO_MEAN_ACTIVATIONS is set to True. If # ASSUME_ZERO_MEAN_ACTIVATIONS is False then mean is always subtracted from the # covariance matrix and this flag is redundant. SUBTRACT_MEAN_CONTRIB_FROM_COV = True # Subsample the inputs passed to the extract image patches. The number of # inputs is normally batch_size. If _SUB_SAMPLE_INPUTS = True then # the inputs will be randomly subsampled down to a total of # _INPUTS_TO_EXTRACT_PATCHES_FACTOR * batch_size. # # Note that the value of _SUB_SAMPLE_INPUTS can be overridden locally for a # particular layer by passing in an argument to the factor class (or the # registration function for the corresponding layer). _SUB_SAMPLE_INPUTS = False _INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.2 # Subsample the extracted image patches during covariance estimation for # input factors in conv layer. The number of patches subsampled will be # calculated based on the following formula: # # if _SUB_SAMPLE_PATCHES: # num_patches = min(_MAX_NUM_PATCHES, # ceil(_MAX_NUM_PATCHES_PER_DIMENSION*dimension)) # else # num_patches = total_patches # # where dimension is the number of rows (or columns) of the input factor matrix, # which is typically the number of input channels times the number of pixels # in a patch. 
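# A small worked example of the formula above (numbers are illustrative): for
# a 3x3 kernel over 256 input channels, dimension = 3 * 3 * 256 = 2304, so
# with the default constants
#
#   num_patches = min(10000000, math.ceil(3.0 * 2304)) = 6912
#
# and at most 6912 of the extracted patches would be kept per covariance
# update (assuming the total number of patches exceeds 6912).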
# # Note that the value of _SUB_SAMPLE_PATCHES can be overridden locally for a # particular layer by passing in an argument to the factor class (or the # registration function for the corresponding layer). _SUB_SAMPLE_PATCHES = False _MAX_NUM_PATCHES = 10000000 _MAX_NUM_PATCHES_PER_DIMENSION = 3.0 # If true we use the custom XLA implementation of an op to compute the second # moment of the patch vectors. Note that _SUB_SAMPLE_PATCHES doesn't do anything # when this is enabled. Also note that _SUB_SAMPLE_INPUTS probably doesn't # need to be used either, since that feature was designed to mitigate the # extreme memory consumption of the naive implementation of this op. _USE_PATCHES_SECOND_MOMENT_OP = False # TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data # passed to the factors from the blocks will be concatenated across towers # (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over # towers will be passed in, and the factors will iterate over this and do the # cov computations separately for each one, averaging the results together. TOWER_STRATEGY = "separate" #TOWER_STRATEGY = "concat" # The variable scope names can be edited by passing a custom sanitizer function. # By default the scope name is unchanged. _GET_SANITIZED_NAME_FN = lambda x: x def set_global_constants(init_covariances_at_zero=None, zero_debias=None, init_inverses_at_zero=None, eigenvalue_decomposition_threshold=None, eigenvalue_clipping_threshold=None, assume_zero_mean_activations=None, subtract_mean_contrib_from_cov=None, sub_sample_inputs=None, inputs_to_extract_patches_factor=None, sub_sample_patches=None, max_num_patches=None, max_num_patches_per_dimension=None, tower_strategy=None, get_sanitized_name_fn=None, use_patches_second_moment_op=None): """Sets various global constants used by the classes in this module.""" global INIT_COVARIANCES_AT_ZERO global ZERO_DEBIAS global INIT_INVERSES_AT_ZERO global EIGENVALUE_DECOMPOSITION_THRESHOLD global EIGENVALUE_CLIPPING_THRESHOLD global ASSUME_ZERO_MEAN_ACTIVATIONS global SUBTRACT_MEAN_CONTRIB_FROM_COV global _SUB_SAMPLE_INPUTS global _INPUTS_TO_EXTRACT_PATCHES_FACTOR global _SUB_SAMPLE_PATCHES global _MAX_NUM_PATCHES global _MAX_NUM_PATCHES_PER_DIMENSION global _GET_SANITIZED_NAME_FN global TOWER_STRATEGY global _USE_PATCHES_SECOND_MOMENT_OP if init_covariances_at_zero is not None: INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero if zero_debias is not None: ZERO_DEBIAS = zero_debias if init_inverses_at_zero is not None: INIT_INVERSES_AT_ZERO = init_inverses_at_zero if eigenvalue_decomposition_threshold is not None: EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold if eigenvalue_clipping_threshold is not None: EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold if assume_zero_mean_activations is not None: ASSUME_ZERO_MEAN_ACTIVATIONS = assume_zero_mean_activations if subtract_mean_contrib_from_cov is not None: SUBTRACT_MEAN_CONTRIB_FROM_COV = subtract_mean_contrib_from_cov if sub_sample_inputs is not None: _SUB_SAMPLE_INPUTS = sub_sample_inputs if inputs_to_extract_patches_factor is not None: _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor if sub_sample_patches is not None: _SUB_SAMPLE_PATCHES = sub_sample_patches if max_num_patches is not None: _MAX_NUM_PATCHES = max_num_patches if max_num_patches_per_dimension is not None: _MAX_NUM_PATCHES_PER_DIMENSION = max_num_patches_per_dimension if tower_strategy is not None: TOWER_STRATEGY = tower_strategy if get_sanitized_name_fn is 
not None: _GET_SANITIZED_NAME_FN = get_sanitized_name_fn if use_patches_second_moment_op is not None: _USE_PATCHES_SECOND_MOMENT_OP = use_patches_second_moment_op if INIT_INVERSES_AT_ZERO: inverse_initializer = tf.zeros_initializer else: inverse_initializer = tf.initializers.identity if INIT_COVARIANCES_AT_ZERO: covariance_initializer = tf.zeros_initializer else: covariance_initializer = tf.initializers.identity if INIT_COVARIANCES_AT_ZERO: diagonal_covariance_initializer = tf.zeros_initializer else: diagonal_covariance_initializer = tf.ones_initializer @contextlib.contextmanager def maybe_place_on_device(device): if device is not None and len(device) and TOWER_STRATEGY == "separate": with tf.device(device): yield else: yield def compute_cov(tensor, tensor_right=None, normalizer=None): """Compute the empirical second moment of the rows of a 2D Tensor. This function is meant to be applied to random matrices for which the true row mean is zero, so that the true second moment equals the true covariance. Args: tensor: A 2D Tensor. tensor_right: An optional 2D Tensor. If provided, this function computes the matrix product tensor^T * tensor_right instead of tensor^T * tensor. normalizer: optional scalar for the estimator (by default, the normalizer is the number of rows of tensor). Returns: A square 2D Tensor with as many rows/cols as the number of input columns. """ if normalizer is None: normalizer = utils.get_shape(tensor)[0] if tensor_right is None: cov = ( tf.matmul(tensor, tensor, transpose_a=True) / tf.cast( normalizer, tensor.dtype)) return (cov + tf.transpose(cov)) / tf.cast(2.0, cov.dtype) else: return (tf.matmul(tensor, tensor_right, transpose_a=True) / tf.cast(normalizer, tensor.dtype)) def append_homog(tensor, homog_value=None): """Appends a homogeneous coordinate to the last dimension of a Tensor. Args: tensor: A Tensor. homog_value: Value to append as homogeneous coordinate to the last dimension of `tensor`. If None 1.0 is used. (Default: None) Returns: A Tensor identical to the input but one larger in the last dimension. The new entries are filled with ones. """ shape = tensor.shape.as_list() rank = len(shape) if any(elt is None for elt in shape): shape = tf.concat([tf.shape(tensor)[:-1], [1]], axis=0) else: shape[-1] = 1 if homog_value is not None: appendage = homog_value * tf.ones(shape, dtype=tensor.dtype) else: appendage = tf.ones(shape, dtype=tensor.dtype) return tf.concat([tensor, appendage], axis=-1) def scope_string_from_params(params): """Builds a variable scope string name from the given parameters. Supported parameters are: * tensors * booleans * ints * strings * depth-1 tuples/lists of ints * any depth tuples/lists of tensors Other parameter types will throw an error. Args: params: A parameter or list of parameters. Returns: A string to use for the variable scope. Raises: ValueError: if params includes an unsupported type. 
""" params = params if isinstance(params, (tuple, list)) else (params,) name_parts = [] for param in params: if param is None: name_parts.append("None") elif isinstance(param, (tuple, list)): if all([isinstance(p, int) for p in param]): name_parts.append("-".join([str(p) for p in param])) else: name_parts.append(scope_string_from_name(param)) elif isinstance(param, (six.string_types, int, bool)): name_parts.append(str(param)) elif isinstance(param, (tf.Tensor, tf.Variable)): name_parts.append(scope_string_from_name(param)) elif isinstance(param, utils.PartitionedTensor): name_parts.append(scope_string_from_name(param.tensors)) else: raise ValueError("Encountered an unsupported param {} of type {}".format( param, type(param))) return "_".join(name_parts) def scope_string_from_name(tensor): if isinstance(tensor, (tuple, list)): return "__".join([scope_string_from_name(t) for t in tensor]) # "gradients/add_4_grad/Reshape:0/replica_0" -> # "gradients_add_4_grad_Reshape_0_replica_0" tensor_name = tensor.name.replace("/", "_").replace(":", "_") return _GET_SANITIZED_NAME_FN(tensor_name) def scalar_or_tensor_to_string(val): return repr(val) if np.isscalar(val) else scope_string_from_name(val) def list_to_string(lst): return "_".join(val if isinstance(val, six.string_types) else scalar_or_tensor_to_string(val) for val in lst) def graph_func_to_id(func): """Returns a hashable object that represents func's computation.""" # TODO(b/74201126): replace with Topohash of func's output return func.func_id def graph_func_to_string(func): # TODO(b/74201126): replace with Topohash of func's output return list_to_string(func.func_id) def _subsample_patches(patches, name=None): """Subsample a patches matrix. Subsample an array of image patches. The number of patches subsampled will be calculated based on the following formula: num_patches = min(_MAX_NUM_PATCHES, ceil(_MAX_NUM_PATCHES_PER_DIMENSION*dimension)) Args: patches: Tensor, of shape `[total_patches, dimension]`. name: `string`, Default (None) Returns: A tensor of shape `[num_patches, dimension]`. Raises: ValueError: If patches is not matrix-shaped. ValueError: If total_patches cannot be inferred. """ with tf.name_scope(name, "subsample", [patches]): patches = tf.convert_to_tensor(patches) if len(patches.shape) != 2: raise ValueError("Input param patches must be a matrix.") total_patches = patches.shape.as_list()[0] dimension = patches.shape.as_list()[1] num_patches = min(_MAX_NUM_PATCHES, int(math.ceil(_MAX_NUM_PATCHES_PER_DIMENSION*dimension))) if total_patches is None: total_patches = utils.get_shape(patches)[0] should_subsample = tf.less(num_patches, total_patches) return tf.cond(should_subsample, lambda: _random_tensor_gather(patches, num_patches, name), lambda: patches) else: if num_patches < total_patches: return _random_tensor_gather(patches, num_patches, name) else: return patches def _random_tensor_gather(array, num_ind, name=None): """Samples random indices of an array (along the first dimension). Args: array: Tensor of shape `[batch_size, ...]`. num_ind: int. Number of indices to sample. name: `string`. (Default: None) Returns: A tensor of shape `[num_ind, ...]`. 
""" with tf.name_scope(name, "random_gather", [array]): array = tf.convert_to_tensor(array) total_size = array.shape.as_list()[0] if total_size is None: total_size = utils.get_shape(array)[0] indices = tf.random_shuffle( tf.range(0, total_size, dtype=utils.preferred_int_dtype()))[:num_ind] return tf.gather(array, indices, axis=0) @six.add_metaclass(abc.ABCMeta) class FisherFactor(object): """Base class for objects modeling factors of approximate Fisher blocks. A FisherFactor represents part of an approximate Fisher Information matrix. For example, one approximation to the Fisher uses the Kronecker product of two FisherFactors A and B, F = kron(A, B). FisherFactors are composed with FisherBlocks to construct a block-diagonal approximation to the full Fisher. FisherFactors are backed by a single, non-trainable variable that is updated by running FisherFactor.make_covariance_update_op(). The shape and type of this variable is implementation specific. Note that for blocks that aren't based on approximations, a 'factor' can be the entire block itself, as is the case for the diagonal and full representations. """ def __init__(self): self._cov_tensor = None self._cov = None self._acc_cov = None @abc.abstractproperty def _var_scope(self): """Variable scope for this FisherFactor instance. Returns: string that unique identifies this FisherFactor instance. """ pass @property def name(self): return self._var_scope @abc.abstractproperty def _cov_shape(self): """The shape of the variable backing this FisherFactor.""" pass @abc.abstractproperty def _num_sources(self): """The number of things to sum over when updating covariance variable. The default make_covariance_update_op function will call _compute_new_cov with indices ranging from 0 to _num_sources-1. The typical situation is where the factor wants to sum the statistics it computes over multiple backpropped "gradients" (typically passed in via "tensors" or "outputs_grads" arguments). """ pass @abc.abstractproperty def _num_towers(self): pass @abc.abstractproperty def _dtype(self): """dtype for variable backing this factor.""" pass @abc.abstractmethod def _partial_batch_size(self, source=0, tower=0): """Returns (partial) batch size associated with given source and tower.""" pass def batch_size(self, source=0): """Returns (total) batch size associated with given source.""" return sum(self._partial_batch_size(source=source, tower=tower) for tower in range(self._num_towers)) def check_partial_batch_sizes(self): """Ensures partial batch sizes are equal across towers and source.""" # While it could be okay in principle to have different batch sizes for # different towers, the way the code has been written isn't compatible with # this. Basically, the normalizations occur for each tower and then the # results are summed across towers and divided by the number of towers. # The only way this is correct is if the towers all have the same batch # size. # Should make these messages use quote characters instead of parentheses # when the bug with quote character rendering in assertion messages is # fixed. See b/129476712 msg = ("Inconsistent (partial) batch sizes detected for factor ({}) of type" " {}. 
This can be caused by passing Tensors with the wrong sizes to " "the registration functions, or misspecification of arguments like " "batch_size, num_uses, or num_timesteps.".format( self.name, utils.cls_name(self))) partial_batch_size = self._partial_batch_size() if self._num_sources > 1 or self._num_towers > 1: if isinstance(partial_batch_size, int): checks = tuple( partial_batch_size == self._partial_batch_size(source=source, tower=tower) for source, tower in zip(range(self._num_sources), range(self._num_towers))) if not all(checks): raise ValueError(msg) return tf.no_op() else: asserts = tuple( tf.assert_equal(partial_batch_size, self._partial_batch_size(source=source, tower=tower), message=msg) for source, tower in zip(range(self._num_sources), range(self._num_towers))) return tf.group(asserts) return tf.no_op() @property def _cov_initializer(self): """Function for initializing covariance variable.""" return covariance_initializer def instantiate_cov_variables(self): """Makes the internal cov variable(s).""" assert self._cov is None with tf.variable_scope(self._var_scope): self._cov = utils.MovingAverageVariable( name="cov", shape=self._cov_shape, dtype=self._dtype, initializer=self._cov_initializer, normalize_value=ZERO_DEBIAS) @abc.abstractmethod def _compute_new_cov(self, source, tower): """Computes minibatch-estimated covariance for a single source. Args: source: int in [0, self._num_sources). Which source to use when computing the cov update. tower: int in [0, self._num_towers). Which tower to use when computing the cov update. Returns: Tensor of same shape as self.cov. """ pass def _compute_total_new_cov(self): """Computes covariance by summing across (source, towers).""" new_cov_contribs = [] for source in range(self._num_sources): for tower in range(self._num_towers): with maybe_place_on_device(self._get_data_device(tower)): new_cov_contribs.append(self._compute_new_cov(source, tower)) new_cov = tf.add_n(new_cov_contribs) / float(self._num_towers) # Compute average of 'new_cov' across all replicas. On a replica, each # instance of 'new_cov' will be based on a different minibatch. This ensures # that by the time variable assignment happens, all replicas have the same # value. # # Other implementations of make_covariance_update_op() that accumulate # statistics in other variables should mimic this behavior. # # NOTE: communicating this matrix at every iteration is wasteful in the # sense that we might only need fresh copies when we do the inversions. # (Although be careful about factors [e.g. diagonal] or ops # [e.g. multiply()] that directly use the cov vars instead of the inv vars!) new_cov = utils.all_average(new_cov) return new_cov def make_covariance_update_op(self, ema_decay, ema_weight): """Constructs and returns the covariance update Op. Args: ema_decay: float or Tensor. The exponential moving average decay. ema_weight: float or Tensor. The weight to put on the newly computed values. This is typically 1.0 - ema_decay. Returns: The op which updates the cov variable (via acc_cov). """ cov_tensor = self._compute_total_new_cov() self._cov_tensor = cov_tensor # This is used for non-standard applications # and debugging I think. 
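# A minimal sketch of the assumed semantics of add_to_average(value, decay,
# weight) used on the next line, including the zero-debiasing applied when
# ZERO_DEBIAS is True (this is an illustration, not the actual
# MovingAverageVariable implementation):
#
#   import numpy as np
#   class _ToyMovingAverage(object):
#     def __init__(self, shape):
#       self._value = np.zeros(shape)   # raw (biased) moving average
#       self._weight = 0.0              # accumulated weight, for debiasing
#     def add_to_average(self, new_value, decay, weight):
#       self._value = decay * self._value + weight * new_value
#       self._weight = decay * self._weight + weight
#     @property
#     def value(self):
#       return self._value / self._weight  # zero-debiased estimate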
return self._cov.add_to_average(cov_tensor, decay=ema_decay, weight=ema_weight) @abc.abstractmethod def _get_data_device(self, tower): pass @abc.abstractmethod def instantiate_inv_variables(self): """Makes the internal "inverse" variable(s).""" pass @abc.abstractmethod def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" pass @property def cov(self): return self._cov.value def get_cov_vars(self): return [self.cov] def get_inv_vars(self): return [] @abc.abstractmethod def get_cov_as_linear_operator(self): """Returns `LinearOperator` instance which wraps the cov matrix.""" pass @abc.abstractmethod def register_matpower(self, exp, damping_func): pass @abc.abstractmethod def register_cholesky(self, damping_func): pass @abc.abstractmethod def register_cholesky_inverse(self, damping_func): pass @abc.abstractmethod def get_matpower(self, exp, damping_func): pass @abc.abstractmethod def get_cholesky(self, damping_func): pass @abc.abstractmethod def get_cholesky_inverse(self, damping_func): pass class DenseSquareMatrixFactor(FisherFactor): """Base class for FisherFactors that are stored as dense square matrices. This class explicitly calculates and stores inverses of their `cov` matrices, which must be square dense matrices. Subclasses must implement the _compute_new_cov method, and the _var_scope and _cov_shape properties. """ # TODO(b/69108481): This class (and its subclasses) should be refactored to # serve the matrix quantities it computes as both (potentially stale) # variables, updated by the inverse update ops, and fresh values stored in # tensors that recomputed once every session.run() call. Currently matpower # and damp_inverse have the former behavior, while eigendecomposition has # the latter. def __init__(self): self._matpower_by_exp_and_damping = OrderedDict() # { (float, hashable): variable } self._matpower_registrations = set() # { (float, hashable) } self._eigendecomp = None self._damping_funcs_by_id = OrderedDict() # {hashable: lambda} self._cholesky_registrations = set() # { hashable } self._cholesky_inverse_registrations = set() # { hashable } self._cholesky_by_damping = OrderedDict() # { hashable: variable } self._cholesky_inverse_by_damping = OrderedDict() # { hashable: variable } super(DenseSquareMatrixFactor, self).__init__() def get_cov_as_linear_operator(self): """Returns `LinearOperator` instance which wraps the cov matrix.""" assert self.cov.shape.ndims == 2 return lo.LinearOperatorFullMatrix(self.cov, is_self_adjoint=True, is_square=True) def _register_damping(self, damping_func): damping_id = graph_func_to_id(damping_func) if damping_id not in self._damping_funcs_by_id: self._damping_funcs_by_id[damping_id] = damping_func return damping_id def register_inverse(self, damping_func): # Just for backwards compatibility of some old code and tests self.register_matpower(-1, damping_func) def register_matpower(self, exp, damping_func): """Registers a matrix power to be maintained and served on demand. This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method get_matpower. Args: exp: float. The exponent to use in the matrix power. damping_func: A function that computes a 0-D Tensor or a float which will be the damping value used. i.e. damping = damping_func(). 
""" if exp == 1.0: return damping_id = self._register_damping(damping_func) if (exp, damping_id) not in self._matpower_registrations: self._matpower_registrations.add((exp, damping_id)) def register_cholesky(self, damping_func): """Registers a Cholesky factor to be maintained and served on demand. This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method get_cholesky. Args: damping_func: A function that computes a 0-D Tensor or a float which will be the damping value used. i.e. damping = damping_func(). """ damping_id = self._register_damping(damping_func) if damping_id not in self._cholesky_registrations: self._cholesky_registrations.add(damping_id) def register_cholesky_inverse(self, damping_func): """Registers an inverse Cholesky factor to be maintained/served on demand. This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method get_cholesky_inverse. Args: damping_func: A function that computes a 0-D Tensor or a float which will be the damping value used. i.e. damping = damping_func(). """ damping_id = self._register_damping(damping_func) if damping_id not in self._cholesky_inverse_registrations: self._cholesky_inverse_registrations.add(damping_id) def get_inv_vars(self): inv_vars = [] inv_vars.extend(self._matpower_by_exp_and_damping.values()) inv_vars.extend(self._cholesky_by_damping.values()) inv_vars.extend(self._cholesky_inverse_by_damping.values()) return inv_vars def instantiate_inv_variables(self): """Makes the internal "inverse" variable(s).""" for (exp, damping_id) in self._matpower_registrations: exp_string = scalar_or_tensor_to_string(exp) damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) with tf.variable_scope(self._var_scope): matpower = tf.get_variable( "matpower_exp{}_damp{}".format(exp_string, damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) assert (exp, damping_id) not in self._matpower_by_exp_and_damping self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower for damping_id in self._cholesky_registrations: damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) with tf.variable_scope(self._var_scope): chol = tf.get_variable( "cholesky_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) assert damping_id not in self._cholesky_by_damping self._cholesky_by_damping[damping_id] = chol for damping_id in self._cholesky_inverse_registrations: damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) with tf.variable_scope(self._var_scope): cholinv = tf.get_variable( "cholesky_inverse_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) assert damping_id not in self._cholesky_inverse_by_damping self._cholesky_inverse_by_damping[damping_id] = cholinv def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" ops = [] num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping if exp == -1) num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses other_matrix_power_registered = num_other_matpower >= 1 use_eig = ( self._eigendecomp or 
other_matrix_power_registered or num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) # We precompute these so we don't need to evaluate them multiple times (for # each matrix power that uses them) damping_value_by_id = {damping_id: tf.cast( self._damping_funcs_by_id[damping_id](), self._dtype) for damping_id in self._damping_funcs_by_id} if use_eig: eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence for (exp, damping_id), matpower in ( self._matpower_by_exp_and_damping.items()): damping = damping_value_by_id[damping_id] ops.append( utils.smart_assign( matpower, tf.matmul(eigenvectors * (eigenvalues + damping)**exp, tf.transpose(eigenvectors)))) # These ops share computation and should be run on a single device. ops = [tf.group(*ops)] else: for (exp, damping_id), matpower in ( self._matpower_by_exp_and_damping.items()): assert exp == -1 damping = damping_value_by_id[damping_id] ops.append( utils.smart_assign(matpower, utils.posdef_inv(self.cov, damping))) # TODO(b/77902055): If inverses are being computed with Cholesky's # we can share the work. Instead this code currently just computes the # Cholesky a second time. It does at least share work between requests for # Cholesky's and Cholesky inverses with the same damping id. for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items(): cholesky_ops = [] damping = damping_value_by_id[damping_id] cholesky_value = utils.cholesky(self.cov, damping) if damping_id in self._cholesky_by_damping: cholesky = self._cholesky_by_damping[damping_id] cholesky_ops.append(utils.smart_assign(cholesky, cholesky_value)) identity = tf.eye( cholesky_value.shape.as_list()[0], dtype=cholesky_value.dtype) cholesky_inv_value = tf.matrix_triangular_solve(cholesky_value, identity) cholesky_ops.append(utils.smart_assign(cholesky_inv, cholesky_inv_value)) ops.append(tf.group(*cholesky_ops)) for damping_id, cholesky in self._cholesky_by_damping.items(): if damping_id not in self._cholesky_inverse_by_damping: damping = damping_value_by_id[damping_id] cholesky_value = utils.cholesky(self.cov, damping) ops.append(utils.smart_assign(cholesky, cholesky_value)) self._eigendecomp = False return ops def get_inverse(self, damping_func): # Just for backwards compatibility of some old code and tests return self.get_matpower(-1, damping_func) def get_matpower(self, exp, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # self.cov (except when exp == 1). if exp != 1: damping_id = graph_func_to_id(damping_func) matpower = self._matpower_by_exp_and_damping[(exp, damping_id)] else: cov = self.cov identity = tf.eye(cov.shape.as_list()[0], dtype=cov.dtype) matpower = cov + tf.cast(damping_func(), dtype=self.cov.dtype)*identity assert matpower.shape.ndims == 2 return lo.LinearOperatorFullMatrix(matpower, is_non_singular=True, is_self_adjoint=True, is_positive_definite=True, is_square=True) def get_cholesky(self, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # self.cov. damping_id = graph_func_to_id(damping_func) cholesky = self._cholesky_by_damping[damping_id] assert cholesky.shape.ndims == 2 return lo.LinearOperatorFullMatrix(cholesky, is_non_singular=True, is_square=True) def get_cholesky_inverse(self, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. 
It may be stale / inconsistent with the latest value of # self.cov. damping_id = graph_func_to_id(damping_func) cholesky_inv = self._cholesky_inverse_by_damping[damping_id] assert cholesky_inv.shape.ndims == 2 return lo.LinearOperatorFullMatrix(cholesky_inv, is_non_singular=True, is_square=True) def get_eigendecomp(self): """Creates or retrieves eigendecomposition of self._cov.""" # Unlike get_matpower this doesn't retrieve a stored variable, but instead # always computes a fresh version from the current value of self.cov. if not self._eigendecomp: eigenvalues, eigenvectors = tf.self_adjoint_eig(self.cov) # The matrix self._cov is positive semidefinite by construction, but the # numerical eigenvalues could be negative due to numerical errors, so here # we clip them to be at least FLAGS.eigenvalue_clipping_threshold clipped_eigenvalues = tf.maximum(eigenvalues, EIGENVALUE_CLIPPING_THRESHOLD) self._eigendecomp = (clipped_eigenvalues, eigenvectors) return self._eigendecomp class NaiveFullFactor(DenseSquareMatrixFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. Note that this uses the naive "square the sum estimator", and so is applicable to any type of parameter in principle, but has very high variance. """ def __init__(self, params_grads, batch_size): self._batch_size = batch_size self._params_grads = tuple(utils.ensure_sequence(params_grad) for params_grad in params_grads) super(NaiveFullFactor, self).__init__() @property def _var_scope(self): return "ff_naivefull_" + scope_string_from_params( [self._params_grads, self._batch_size]) @property def _cov_shape(self): size = sum(param_grad.shape.num_elements() for param_grad in self._params_grads[0]) return (size, size) @property def _num_sources(self): return len(self._params_grads) @property def _num_towers(self): return 1 @property def _dtype(self): return self._params_grads[0][0].dtype def _partial_batch_size(self, source=0, tower=0): assert source == 0 and tower == 0 return self._batch_size def _compute_new_cov(self, source, tower): assert tower == 0 # This will be a very basic rank 1 estimate params_grads_flat = utils.tensors_to_column(self._params_grads[source]) return ((params_grads_flat * tf.transpose(params_grads_flat)) / tf.cast( self._batch_size, params_grads_flat.dtype)) def _get_data_device(self, tower): return None @six.add_metaclass(abc.ABCMeta) class DiagonalFactor(FisherFactor): """A base class for FisherFactors that use diagonal approximations. A DiagonalFactor's covariance variable can be of any shape, but must contain exactly one entry per parameter. 
""" def get_cov_as_linear_operator(self): """Returns `LinearOperator` instance which wraps the cov matrix.""" return lo.LinearOperatorDiag(self._matrix_diagonal, is_self_adjoint=True, is_square=True) @property def _cov_initializer(self): return diagonal_covariance_initializer @property def _matrix_diagonal(self): return tf.reshape(self.cov, [-1]) def make_inverse_update_ops(self): return [] def instantiate_inv_variables(self): pass def register_matpower(self, exp, damping_func): pass def register_cholesky(self, damping_func): pass def register_cholesky_inverse(self, damping_func): pass def get_matpower(self, exp, damping_func): matpower_diagonal = (self._matrix_diagonal + tf.cast(damping_func(), self._dtype))**exp return lo.LinearOperatorDiag(matpower_diagonal, is_non_singular=True, is_self_adjoint=True, is_positive_definite=True, is_square=True) def get_cholesky(self, damping_func): return self.get_matpower(0.5, damping_func) def get_cholesky_inverse(self, damping_func): return self.get_matpower(-0.5, damping_func) class NaiveDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approximation of any type of param's Fisher. Note that this uses the naive "square the sum estimator", and so is applicable to any type of parameter in principle, but has very high variance. """ def __init__(self, params_grads, batch_size): """Initializes NaiveDiagonalFactor instance. Args: params_grads: List of tensors (or lists), with the first index corresponding to source, and the second optional index corresponding to the element of the parameter list. batch_size: int or 0-D Tensor. The batch size. """ self._params_grads = params_grads self._batch_size = batch_size super(NaiveDiagonalFactor, self).__init__() @property def _var_scope(self): return "ff_naivediag_" + scope_string_from_params( [self._params_grads, self._batch_size]) @property def _cov_shape(self): return self._params_grads[0].shape @property def _num_sources(self): return len(self._params_grads) @property def _num_towers(self): return 1 @property def _dtype(self): return self._params_grads[0].dtype def _partial_batch_size(self, source=0, tower=0): assert source == 0 and tower == 0 return self._batch_size def _compute_new_cov(self, source, tower): assert tower == 0 return (tf.square(self._params_grads[source]) / tf.cast( self._batch_size, self._params_grads[source].dtype)) def _get_data_device(self, tower): return None class DiagonalKroneckerFactor(DiagonalFactor): """A Kronecker FisherFactor using diagonal approximations. This class handles both sparse and dense inputs. The covariance is estimated using the diagonal covariance matrix. For a dense tensor: Cov(inputs, inputs) = (1/batch_size) sum_{i} diag(inputs[i,:] ** 2). For sparse inputs, one of the most common use cases is the sparse input to an embedding layer. Given tensor = [batch_size, input_size] representing indices into an [vocab_size, embedding_size] embedding matrix, the diagonal covariance matrix is Cov(inputs, inputs) = (1/batch_size) sum_{i} diag(n_hot(inputs[i]) ** 2). where inputs[i] is the ith list of input ids, n_hot() constructs an n-hot binary vector and diag() constructs a diagonal matrix of size [vocab_size, vocab_size]. """ def __init__(self, tensors, has_bias=False, dtype=None): """Instantiate DiagonalKroneckerFactor. Args: tensors: List of list of Tensors, each of shape [batch_size, n]. First index is source, second index is tower. Two types of tensors are supported. Dense tensors are typically either a layer's inputs or its output's gradients. 
Sparse tensors are typically indices into an [vocab_size, embedding_dim] embedding matrix. Sparse tensors must have a property named "one_hot_depth" indicating the depth of one-hot tensors they should be converted to. dtype: dtype for covariance statistics. Only used for sparse inputs. Must be a floating point type. Defaults to float32. has_bias: bool. If True, append '1' to each input. """ self._tensors = tensors dtype = dtype or tf.float32 self._has_bias = has_bias self._one_hot_depth = getattr(self._tensors[0][0], "one_hot_depth", None) if self._one_hot_depth is None: self._dense_input = True self._cov_dtype = self._tensors[0][0].dtype else: self._dense_input = False self._cov_dtype = dtype super(DiagonalKroneckerFactor, self).__init__() @property def _var_scope(self): return "ff_diag_kron_" + scope_string_from_params( nest.flatten(self._tensors)) @property def _cov_shape(self): if self._dense_input: size = self._tensors[0][0].shape[1] + self._has_bias else: size = self._one_hot_depth + self._has_bias return [size] @property def _num_sources(self): return len(self._tensors) @property def _num_towers(self): return len(self._tensors[0]) @property def _dtype(self): return self._cov_dtype def _partial_batch_size(self, source=0, tower=0): return utils.get_shape(self._tensors[source][tower])[0] def _compute_new_cov(self, source, tower): return self._compute_new_cov_from_tensor(self._tensors[source][tower]) def _compute_new_cov_from_tensor(self, tensor): batch_size = utils.get_shape(tensor)[0] if self._dense_input: if len(tensor.shape) != 2: raise ValueError( "Dense input tensors to DiagonalKroneckerFactor must have " "rank == 2. Found tensor with wrong rank: {}".format(tensor)) new_cov = tf.square(tensor) else: if len(tensor.shape) != 1: raise ValueError( "Sparse input tensors to DiagonalKroneckerFactor must have " "rank == 1. Found tensor with wrong rank: {}".format(tensor)) # Transform indices into one-hot vectors. # # TODO(b/72714822): There must be a faster way to construct the diagonal # covariance matrix! This operation is O(batch_size * vocab_size), where # it should be O(batch_size * input_size). flat_input_ids = tf.reshape(tensor, [-1]) new_cov = tf.one_hot(flat_input_ids, self._one_hot_depth) # [?, vocab_size] # Take average across examples. Note that, because all entries have # magnitude zero or one, there's no need to square the entries. # # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation # within an example such as average. # # TODO(b/72714822): Support for partitioned embeddings. new_cov = tf.reduce_sum(new_cov, axis=0) new_cov /= tf.cast(batch_size, new_cov.dtype) if self._has_bias: new_cov = append_homog(new_cov) return new_cov def _get_data_device(self, tower): return self._tensors[0][tower].device class DiagonalMultiKF(DiagonalKroneckerFactor): def __init__(self, tensors, num_uses, has_bias=False, dtype=None): super(DiagonalMultiKF, self).__init__( tensors, dtype=dtype, has_bias=has_bias) self._num_uses = num_uses def _partial_batch_size(self, source=0, tower=0): # Note that some internal comptutations of "batch_size" done in the parent # class won't actually be the proper batch size. Instead, they will be # just "the thing to normalize the statistics by", essentially. This is okay # as we don't mix the two things up. 
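# For example (illustrative numbers): with num_uses = 5 and a folded dense
# input of shape [160, n] (time folded into the batch dimension), the value
# computed below is 160 // 5 = 32, the per-tower batch size; for the unfolded
# rank-3 case of shape [5, 32, n] it is simply the second dimension, 32.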
shape = utils.get_shape(self._tensors[source][tower]) if self._dense_input: if len(shape) == 2: # the folded case return shape[0] // self._num_uses elif len(shape) == 3: return shape[1] # batch is the second dim else: if len(shape) == 1: # the folded case return shape[0] // self._num_uses elif len(shape) == 2: return shape[1] # batch is the second dim @property def _cov_shape(self): if self._dense_input: shape = self._tensors[0][0].shape if len(shape) == 2: size = shape[1] + self._has_bias elif len(shape) == 3: size = shape[2] + self._has_bias else: size = self._one_hot_depth + self._has_bias return [size] def _compute_new_cov(self, source, tower): tensor = self._tensors[source][tower] if self._dense_input: if len(tensor.shape) == 3: tensor = tf.reshape(tensor, [-1, tensor.shape[2]]) else: if len(tensor.shape) == 2: tensor = tf.reshape(tensor, [-1]) return self._compute_new_cov_from_tensor(tensor) class FullyConnectedDiagonalFactor(DiagonalFactor): r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher. Given in = [batch_size, input_size] and out_grad = [batch_size, output_size], approximates the covariance as, Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0 where the square is taken element-wise. """ def __init__(self, inputs, outputs_grads, has_bias=False): """Instantiate FullyConnectedDiagonalFactor. Args: inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this layer. List index is towers. outputs_grads: List of Tensors, each of shape [batch_size, output_size], which are the gradients of the loss with respect to the layer's outputs. First index is source, second is tower. has_bias: bool. If True, append '1' to each input. """ self._inputs = inputs self._has_bias = has_bias self._outputs_grads = outputs_grads self._squared_inputs = None super(FullyConnectedDiagonalFactor, self).__init__() @property def _var_scope(self): return "ff_diagfc_" + scope_string_from_params( tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) @property def _cov_shape(self): input_size = self._inputs[0].shape[1] + self._has_bias output_size = self._outputs_grads[0][0].shape[1] return [input_size, output_size] @property def _num_sources(self): return len(self._outputs_grads) @property def _num_towers(self): return len(self._inputs) @property def _dtype(self): return self._outputs_grads[0][0].dtype def _partial_batch_size(self, source=0, tower=0): return utils.get_shape(self._outputs_grads[source][tower])[0] def make_covariance_update_op(self, ema_decay, ema_weight): self._squared_inputs = [] for tower in range(self._num_towers): inputs = self._inputs[tower] with maybe_place_on_device(self._get_data_device(tower)): if self._has_bias: inputs = append_homog(inputs) self._squared_inputs.append(tf.square(inputs)) return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op( ema_decay, ema_weight) def _compute_new_cov(self, source, tower): batch_size = utils.get_shape(self._squared_inputs[tower])[0] outputs_grad = self._outputs_grads[source][tower] # The well-known special formula that uses the fact that the entry-wise # square of an outer product is the outer-product of the entry-wise squares. # The gradient is the outer product of the input and the output gradients, # so we just square both and then take their outer-product. 
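# A small NumPy check (illustrative) of the identity used below: the
# element-wise square of an outer product equals the outer product of the
# element-wise squares, so the batch-summed squared per-example gradients
# reduce to a single matmul of squared inputs with squared output gradients:
#
#   import numpy as np
#   a = np.random.randn(4, 3)   # inputs        [batch, input_size]
#   g = np.random.randn(4, 2)   # output grads  [batch, output_size]
#   direct = sum(np.outer(a[i], g[i])**2 for i in range(4)) / 4.0
#   fast = (a**2).T.dot(g**2) / 4.0
#   assert np.allclose(direct, fast)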
new_cov = tf.matmul( self._squared_inputs[tower], tf.square(outputs_grad), transpose_a=True) new_cov /= tf.cast(batch_size, new_cov.dtype) return new_cov def _get_data_device(self, tower): return self._inputs[tower].device @six.add_metaclass(abc.ABCMeta) class ScaleAndShiftFactor(FisherFactor): def __init__(self, inputs, outputs_grads, broadcast_dims_scale, broadcast_dims_shift=None, has_shift=True, approx="full"): assert approx == "full" or approx == "diagonal" self._inputs = inputs self._outputs_grads = outputs_grads self._broadcast_dims_scale = broadcast_dims_scale self._broadcast_dims_shift = broadcast_dims_shift self._has_shift = has_shift self._approx = approx assert not has_shift or broadcast_dims_shift is not None super(ScaleAndShiftFactor, self).__init__() @property def _var_scope(self): return "ff_scaleshift_" + scope_string_from_params( [self._inputs, self._outputs_grads, self._broadcast_dims_scale, self._broadcast_dims_shift, self._has_shift, self._approx]) @property def _cov_shape(self): size = np.prod([ self._inputs[0].shape[i] for i in range(1, len(self._inputs[0].shape)) if i not in self._broadcast_dims_scale], dtype=np.int64) if self._has_shift: size_shift = np.prod([ self._outputs_grads[0][0].shape[i] for i in range(1, len(self._outputs_grads[0][0].shape)) if i not in self._broadcast_dims_shift], dtype=np.int64) size += size_shift if self._approx == "full": return (size, size) elif self._approx == "diagonal": return (size,) @property def _num_sources(self): return len(self._outputs_grads) @property def _num_towers(self): return len(self._inputs) @property def _dtype(self): return self._inputs[0].dtype def _partial_batch_size(self, source=0, tower=0): return utils.get_shape(self._outputs_grads[source][tower])[0] def _compute_new_cov(self, source, tower): # Here we implement a "sum of squares" estimator that uses the special # structure of the scale & shift operation. In particular, we sum across # all dimensions that broadcast, then square (or take outer-products), and # then average across the mini-batch. inputs = self._inputs[tower] outputs_grad = self._outputs_grads[source][tower] batch_size = utils.get_shape(inputs)[0] assert len(inputs.shape) == len(outputs_grad.shape) for i in range(1, len(inputs.shape)): assert inputs.shape[i] <= outputs_grad.shape[i] # The formula for the gradient of the shift param is just the element-wise # product of the inputs and the output gradients, summed across the # dimensions that get broadcasted. scale_grads = tf.reduce_sum(inputs * outputs_grad, axis=self._broadcast_dims_scale) scale_grads_flat = tf.reshape(scale_grads, [batch_size, -1]) if self._has_shift: # The formula for the gradient of the shift param is just the output # gradients, summed across the dimensions that get broadcasted. 
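# An illustrative NumPy sketch of both per-example parameter gradients for a
# layer of the assumed form out = scale * x + shift, where scale and shift
# broadcast over axis 1 (the shapes and axes below are hypothetical):
#
#   import numpy as np
#   x = np.random.randn(8, 10, 4)      # [batch, broadcast_dim, channels]
#   dout = np.random.randn(8, 10, 4)   # gradient of the loss w.r.t. outputs
#   scale_grads = (x * dout).sum(axis=1)  # [batch, channels]
#   shift_grads = dout.sum(axis=1)        # [batch, channels]
#
# Squaring and batch-averaging these gives the "diagonal" covariance, while
# compute_cov of their concatenation gives the "full" covariance.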
shift_grads = tf.reduce_sum(outputs_grad, axis=self._broadcast_dims_shift) shift_grads_flat = tf.reshape(shift_grads, [batch_size, -1]) params_grads_flat = tf.concat([scale_grads_flat, shift_grads_flat], axis=1) else: params_grads_flat = scale_grads_flat if self._approx == "full": new_cov = compute_cov(params_grads_flat) elif self._approx == "diagonal": new_cov = tf.reduce_mean(tf.square(params_grads_flat), axis=0) return new_cov def _get_data_device(self, tower): return self._inputs[tower].device class ScaleAndShiftFullFactor(ScaleAndShiftFactor, DenseSquareMatrixFactor): def __init__(self, inputs, outputs_grads, broadcast_dims_scale, broadcast_dims_shift=None, has_shift=True): super(ScaleAndShiftFullFactor, self).__init__( inputs, outputs_grads, broadcast_dims_scale, broadcast_dims_shift=broadcast_dims_shift, has_shift=has_shift, approx="full") class ScaleAndShiftDiagonalFactor(ScaleAndShiftFactor, DiagonalFactor): def __init__(self, inputs, outputs_grads, broadcast_dims_scale, broadcast_dims_shift=None, has_shift=True): super(ScaleAndShiftDiagonalFactor, self).__init__( inputs, outputs_grads, broadcast_dims_scale, broadcast_dims_shift=broadcast_dims_shift, has_shift=has_shift, approx="diagonal") class ConvDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approx of a convolutional layer's Fisher.""" def __init__(self, inputs, outputs_grads, filter_shape, strides, padding, data_format=None, dilations=None, has_bias=False, patch_mask=None): """Creates a ConvDiagonalFactor object. Args: inputs: List of Tensors of shape [batch_size, height, width, in_channels]. Input activations to this layer. List index is towers. outputs_grads: List of Tensors, each of shape [batch_size, height, width, out_channels], which are the gradients of the loss with respect to the layer's outputs. First index is source, second index is tower. filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels, out_channels). Represents shape of kernel used in this layer. strides: The stride size in this layer (1-D Tensor of length 4). padding: The padding in this layer (1-D of Tensor length 4). data_format: None or str. Format of conv2d inputs. dilations: None or tuple of 4 ints. has_bias: Python bool. If True, the layer is assumed to have a bias parameter in addition to its filter parameter. patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels] or None. If not None this is multiplied against the extracted patches Tensor (broadcasting along the batch dimension) before statistics are computed. (Default: None) Raises: ValueError: If inputs, output_grads, and filter_shape do not agree on in_channels or out_channels. ValueError: If strides, dilations are not length-4 lists of ints. ValueError: If data_format does not put channel last. """ if not utils.is_data_format_channel_last(data_format): raise ValueError("Channel must be last.") if any(input_.shape.ndims != 4 for input_ in inputs): raise ValueError("inputs must be a list of 4-D Tensors.") if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs): raise ValueError("inputs and filter_shape must agree on in_channels.") for i, outputs_grad in enumerate(outputs_grads): if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad): raise ValueError("outputs[%d] must be 4-D Tensor." % i) if any(output_grad.shape.as_list()[-1] != filter_shape[-1] for output_grad in outputs_grad): raise ValueError( "outputs[%d] and filter_shape must agree on out_channels." 
% i) if len(strides) != 4: raise ValueError("strides must be length-4 list of ints.") if dilations is not None and len(dilations) != 4: raise ValueError("dilations must be length-4 list of ints.") self._inputs = inputs self._outputs_grads = outputs_grads self._filter_shape = filter_shape self._strides = strides self._padding = padding self._data_format = data_format self._dilations = dilations self._has_bias = has_bias self._patches = None self._patch_mask = patch_mask super(ConvDiagonalFactor, self).__init__() @property def _var_scope(self): return "ff_convdiag_" + scope_string_from_params( tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) @property def _cov_shape(self): filter_height, filter_width, in_channels, out_channels = self._filter_shape return [ filter_height * filter_width * in_channels + self._has_bias, out_channels ] @property def _num_sources(self): return len(self._outputs_grads) @property def _num_towers(self): return len(self._inputs) @property def _dtype(self): return self._inputs[0].dtype def _partial_batch_size(self, source=0, tower=0): return utils.get_shape(self._outputs_grads[source][tower])[0] def make_covariance_update_op(self, ema_decay, ema_weight): filter_height, filter_width, _, _ = self._filter_shape # TODO(b/64144716): there is potential here for a big savings in terms # of memory use. if self._dilations is None: rates = (1, 1, 1, 1) else: rates = tuple(self._dilations) self._patches = [] for tower in range(self._num_towers): with maybe_place_on_device(self._get_data_device(tower)): patches = tf.extract_image_patches( self._inputs[tower], ksizes=[1, filter_height, filter_width, 1], strides=self._strides, rates=rates, padding=self._padding) if self._patch_mask is not None: assert self._patch_mask.shape == self._filter_shape[0:-1] # This should work as intended due to broadcasting. patches *= self._patch_mask if self._has_bias: patches = append_homog(patches) self._patches.append(patches) return super(ConvDiagonalFactor, self).make_covariance_update_op( ema_decay, ema_weight) def _compute_new_cov(self, source, tower): patches = self._patches[tower] batch_size = utils.get_shape(patches)[0] outputs_grad = self._outputs_grads[source][tower] new_cov = self._convdiag_sum_of_squares(patches, outputs_grad) new_cov /= tf.cast(batch_size, new_cov.dtype) return new_cov def _convdiag_sum_of_squares(self, patches, outputs_grad): # This computes the sum of the squares of the per-training-case "gradients". # It does this simply by computing a giant tensor containing all of these, # doing an entry-wise square, and them summing along the batch dimension. case_wise_gradients = tf.einsum("bijk,bijl->bkl", patches, outputs_grad) return tf.reduce_sum(tf.square(case_wise_gradients), axis=0) def _get_data_device(self, tower): return self._inputs[tower].device class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor): """Kronecker factor for the input or output side of a fully-connected layer. """ def __init__(self, tensors, has_bias=False): """Instantiate FullyConnectedKroneckerFactor. Args: tensors: List of list of Tensors, each of shape [batch_size, n]. The Tensors are typically either a layer's inputs or its output's gradients. The first list index is source, the second is tower. has_bias: bool. If True, append '1' to each row. """ # The tensor argument is either a tensor of input activations or a tensor of # output pre-activation gradients. 
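# Illustrative sketch (not part of the library) of the statistic this factor
# maintains: the second moment of the layer's inputs (or of its output
# gradients), with a homogeneous coordinate appended when has_bias is True:
#
#   import numpy as np
#   a = np.random.randn(32, 100)                          # [batch, in_dim]
#   a_h = np.concatenate([a, np.ones((32, 1))], axis=1)   # has_bias case
#   A = a_h.T.dot(a_h) / 32.0                             # input-side factor
#
# The block built from two such factors then approximates its Fisher block as
# kron(A, G), with G computed the same way from the output gradients.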
self._has_bias = has_bias self._tensors = tensors self._one_hot_depth = getattr(self._tensors[0][0], "one_hot_depth", None) if self._one_hot_depth is not None: raise ValueError("Dense factors currently don't support 1-hot sparse " "data. Note that for input factors with such data, " "a diagonal approximation is exact (but the same is " "NOT true for output factors).") super(FullyConnectedKroneckerFactor, self).__init__() @property def _var_scope(self): return "ff_fckron_" + scope_string_from_params( tuple(nest.flatten(self._tensors)) + (self._has_bias,)) @property def _cov_shape(self): size = self._tensors[0][0].shape[1] + self._has_bias return [size, size] @property def _num_sources(self): return len(self._tensors) @property def _num_towers(self): return len(self._tensors[0]) @property def _dtype(self): return self._tensors[0][0].dtype def _partial_batch_size(self, source=0, tower=0): return utils.get_shape(self._tensors[source][tower])[0] def _compute_new_cov(self, source, tower): tensor = self._tensors[source][tower] if self._has_bias: tensor = append_homog(tensor) return compute_cov(tensor) def _get_data_device(self, tower): return self._tensors[0][tower].device class ConvInputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the input side of a convolutional layer. Estimates E[ a a^T ] where a is the inputs to a convolutional layer given example x. Expectation is taken over all examples and locations. Note that this is related to Omega in https://arxiv.org/abs/1602.01407 except that here we normalize by the number of locations (k). By setting the renormalization coefficient ("_renorm_coeff") in the block class to k we get the same overall block approximation from the paper. """ def __init__(self, inputs, filter_shape, padding, strides=None, dilation_rate=None, data_format=None, extract_patches_fn=None, has_bias=False, sub_sample_inputs=None, sub_sample_patches=None, patch_mask=None): """Initializes ConvInputKroneckerFactor. Args: inputs: List of Tensors of shape [batch_size, ..spatial_input_size.., in_channels]. Inputs to layer. List index is tower. filter_shape: List of ints. Contains [..spatial_filter_size.., in_channels, out_channels]. Shape of convolution kernel. padding: str. Padding method for layer. "SAME" or "VALID". strides: List of ints or None. Contains [..spatial_filter_strides..] if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_filter_strides, 1]. dilation_rate: List of ints or None. Rate for dilation along each spatial dimension if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. data_format: str or None. Format of input data. extract_patches_fn: str or None. Name of function that extracts image patches. One of "extract_convolution_patches", "extract_image_patches", "extract_pointwise_conv2d_patches". has_bias: bool. If True, append 1 to in_channel. sub_sample_inputs: `bool`. If True, then subsample the inputs from which the image patches are extracted. (Default: None) sub_sample_patches: `bool`, If `True` then subsample the extracted patches. (Default: None) patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels] or None. If not None this is multiplied against the extracted patches Tensor (broadcasting along the batch dimension) before statistics are computed. 
(Default: None) """ self._inputs = inputs self._filter_shape = filter_shape self._strides = strides self._padding = padding self._dilation_rate = dilation_rate self._data_format = data_format self._extract_patches_fn = extract_patches_fn self._has_bias = has_bias if sub_sample_inputs is None: self._sub_sample_inputs = _SUB_SAMPLE_INPUTS else: self._sub_sample_inputs = sub_sample_inputs if sub_sample_patches is None: self._sub_sample_patches = _SUB_SAMPLE_PATCHES else: self._sub_sample_patches = sub_sample_patches self._patch_mask = patch_mask super(ConvInputKroneckerFactor, self).__init__() @property def _var_scope(self): return "ff_convinkron_" + scope_string_from_params( tuple(self._inputs) + tuple((self._filter_shape, self._strides, self._padding, self._dilation_rate, self._data_format, self._has_bias, self._patch_mask))) @property def _cov_shape(self): spatial_filter_shape = self._filter_shape[0:-2] in_channels = self._filter_shape[-2] size = np.prod(spatial_filter_shape) * in_channels + self._has_bias return [size, size] @property def _num_sources(self): return 1 @property def _num_towers(self): return len(self._inputs) @property def _dtype(self): return self._inputs[0].dtype def _partial_batch_size(self, source=0, tower=0): assert source == 0 return utils.get_shape(self._inputs[tower])[0] def _compute_new_cov(self, source, tower): assert source == 0 inputs = self._inputs[tower] if self._sub_sample_inputs: batch_size = inputs.shape.as_list()[0] if batch_size is None: # dynamic case: batch_size = utils.get_shape(inputs)[0] # computes: int(math.ceil(batch_size # * _INPUTS_TO_EXTRACT_PATCHES_FACTOR)) new_size = tf.cast( tf.ceil(tf.multiply(tf.cast(batch_size, dtype=tf.float32), _INPUTS_TO_EXTRACT_PATCHES_FACTOR)), dtype=utils.preferred_int_dtype()) else: # static case: new_size = int(math.ceil(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR)) inputs = _random_tensor_gather(inputs, new_size) # TODO(b/64144716): there is potential here for a big savings in terms of # memory use. if _USE_PATCHES_SECOND_MOMENT_OP: raise NotImplementedError # patches op is not available outside of Google, # sorry! You'll need to turn it off to proceed. else: if self._extract_patches_fn in [None, "extract_convolution_patches"]: patches = utils.extract_convolution_patches( inputs, self._filter_shape, padding=self._padding, strides=self._strides, dilation_rate=self._dilation_rate, data_format=self._data_format) elif self._extract_patches_fn == "extract_image_patches": assert inputs.shape.ndims == 4 assert len(self._filter_shape) == 4 assert len(self._strides) == 4, self._strides if self._dilation_rate is None: rates = [1, 1, 1, 1] else: rates = self._dilation_rate assert len(rates) == 4 assert rates[0] == rates[-1] == 1 patches = tf.extract_image_patches( inputs, ksizes=[1] + list(self._filter_shape[0:-2]) + [1], strides=self._strides, rates=rates, padding=self._padding) elif self._extract_patches_fn == "extract_pointwise_conv2d_patches": assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)] assert self._filter_shape[0] == self._filter_shape[1] == 1 patches = utils.extract_pointwise_conv2d_patches( inputs, self._filter_shape, data_format=None) else: raise NotImplementedError(self._extract_patches_fn) if self._patch_mask is not None: assert self._patch_mask.shape == self._filter_shape[0:-1] # This should work as intended due to broadcasting. 
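# Illustrative shape check (assuming 4-D image inputs and the
# "extract_image_patches" path): for inputs of shape [M, H, W, J] and a
# k_h x k_w kernel, `patches` has shape [M, H', W', k_h * k_w * J], so the
# [k_h, k_w, J] mask flattened to length k_h * k_w * J broadcasts over the
# batch and spatial dimensions in the multiply below:
#
#   import numpy as np
#   patches = np.random.randn(2, 5, 5, 3 * 3 * 16)  # [M, H', W', kh*kw*J]
#   mask = np.ones((3, 3, 16))
#   masked = patches * mask.reshape(-1)             # broadcasts as intended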
patches *= tf.reshape(self._patch_mask, [-1]) flatten_size = np.prod(self._filter_shape[0:-1]) # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), # where M = minibatch size, |T| = number of spatial locations, # |Delta| = number of spatial offsets, and J = number of input maps # for convolutional layer l. patches_flat = tf.reshape(patches, [-1, flatten_size]) # We append a homogenous coordinate to patches_flat if the layer has # bias parameters. This gives us [[A_l]]_H from the paper. if self._sub_sample_patches: patches_flat = _subsample_patches(patches_flat) if self._has_bias: patches_flat = append_homog(patches_flat) # We call compute_cov without passing in a normalizer. compute_cov uses # the first dimension of patches_flat i.e. M|T| as the normalizer by # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from # the paper but has a different scale here for consistency with # ConvOutputKroneckerFactor. # (Tilde omitted over A for clarity.) return compute_cov(patches_flat) def _get_data_device(self, tower): return self._inputs[tower].device class ConvInputMultiKF(ConvInputKroneckerFactor): def __init__(self, inputs, filter_shape, padding, num_uses, strides=None, dilation_rate=None, data_format=None, extract_patches_fn=None, has_bias=False, sub_sample_inputs=None, sub_sample_patches=None, patch_mask=None): super(ConvInputMultiKF, self).__init__(inputs, filter_shape, padding, strides=strides, dilation_rate=dilation_rate, data_format=data_format, extract_patches_fn=extract_patches_fn, has_bias=has_bias, sub_sample_inputs=sub_sample_inputs, sub_sample_patches=sub_sample_patches, patch_mask=patch_mask) self._num_uses = num_uses def _partial_batch_size(self, source=0, tower=0): # Note that some internal comptutations of "batch_size" done in the parent # class won't actually be the proper batch size. Instead, they will be # just "the thing to normalize the statistics by", essentially. This is okay # as we don't mix the two things up. return (super(ConvInputMultiKF, self)._partial_batch_size(source=source, tower=tower) // self._num_uses) class ConvInputSUAKroneckerFactor(FisherFactor): r"""Kronecker factor for the input side of a convolutional layer. Assumes activations across locations are uncorrelated. Check section 4.2 Theorem 4 in https://arxiv.org/pdf/1602.01407.pdf for further details on the assumptions. This is a computationally more efficient approximation, especially for very wide layers. """ def __init__(self, inputs, filter_shape, has_bias=False): """Initializes ConvInputSUAKroneckerFactor. If `ASSUME_ZERO_MEAN_ACTIVATIONS` is `True` then assumes activations zero mean and the contribution from `M(j) M(j')` term in Theorem 4 from https://arxiv.org/pdf/1602.01407.pdf is ignored. Args: inputs: List of Tensors of shape [batch_size, ..spatial_input_size.., in_channels]. Inputs to layer. List index is tower. filter_shape: List of ints. Contains [..spatial_filter_size.., in_channels, out_channels]. Shape of convolution kernel. has_bias: bool. If True, appends 1 to mean activations. 
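    Note: purely for intuition (none of the names below are part of this
    class's API, and the sketch ignores the bias column and damping), the
    dense input factor represented here is approximately a Kronecker product
    of the [in_channels, in_channels] channel covariance with an identity over
    the kernel_height * kernel_width spatial offsets, plus a rank-one
    correction for the mean activations when `ASSUME_ZERO_MEAN_ACTIVATIONS`
    is False. A rough NumPy sketch:

      import numpy as np
      in_channels, kw_kh = 3, 9                  # e.g. a 3x3 filter
      sigma = np.eye(in_channels)                # placeholder channel covariance
      mu = np.random.rand(in_channels, 1)        # placeholder mean activations
      hatmu = np.kron(mu, np.ones((kw_kh, 1)))   # mu replicated over offsets
      dense = np.kron(sigma, np.eye(kw_kh)) + hatmu @ hatmu.T

    The inverse served by `get_matpower(-1, ...)` exploits exactly this
    structure (Kronecker product plus a rank-one term handled via the
    Sherman-Morrison formula) instead of forming the dense matrix.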
""" self._inputs = inputs self._filter_shape = filter_shape self._has_bias = has_bias self._kw_kh = np.prod(self._filter_shape[0:-2]) self._in_channels = self._filter_shape[-2] self._matpower_by_exp_and_damping = OrderedDict() # { (float, hashable): variable } self._matpower_registrations = set() # { (float, hashable) } self._damping_funcs_by_id = OrderedDict() # {hashable: lambda} self._damping_var_by_id = OrderedDict() if not ASSUME_ZERO_MEAN_ACTIVATIONS: self._cov_inv_mu_by_damping_id = OrderedDict() self._rank_one_update_scale_by_damping_id = OrderedDict() super(ConvInputSUAKroneckerFactor, self).__init__() @property def _var_scope(self): return "ff_convinsuakron_" + scope_string_from_params( tuple(self._inputs) + tuple((self._filter_shape, self._has_bias))) @property def _cov_shape(self): """Returns a list with value [in_channels, in_channels]. NOTE: This does not return the shape of the full cov matrix. But returns the shape of the matrix which computes the covariance of the input channel activations under the assumption mentioned in Theorem 4 in https://arxiv.org/pdf/1602.01407.pdf. This does not include bias dimension and also includes only the `Sigma` term from Theorem 4 in the paper. """ return [self._in_channels, self._in_channels] @property def _num_sources(self): return 1 @property def _num_towers(self): return len(self._inputs) @property def _dtype(self): return self._inputs[0].dtype @property def mu(self): return self._mu.value def _partial_batch_size(self, source=0, tower=0): assert source == 0 return utils.get_shape(self._inputs[tower])[0] def _register_damping(self, damping_func): damping_id = graph_func_to_id(damping_func) if damping_id not in self._damping_funcs_by_id: self._damping_funcs_by_id[damping_id] = damping_func return damping_id def get_inv_vars(self): inv_vars = [] inv_vars.extend(self._matpower_by_exp_and_damping.values()) return inv_vars def instantiate_cov_variables(self): """Makes the internal cov variable(s).""" super(ConvInputSUAKroneckerFactor, self).instantiate_cov_variables() # Create variables for computing the mean activations only if # `ASSUME_ZERO_MEAN_ACTIVATIONS` is set to `False`. Otherwise the # contribution from the second term in equation 35 in the paper # https://arxiv.org/pdf/1602.01407.pdf is ignored. if not ASSUME_ZERO_MEAN_ACTIVATIONS: with tf.variable_scope(self._var_scope): self._mu = utils.MovingAverageVariable( name="mu", shape=(self._in_channels, 1), # number of input channels. dtype=self._dtype, initializer=tf.zeros_initializer(), normalize_value=ZERO_DEBIAS) def make_covariance_update_op(self, ema_decay, ema_weight): """Constructs and returns the covariance update Op. Args: ema_decay: The exponential moving average decay (float or Tensor). ema_weight: float or Tensor. The weight to put on the newly computed values. This is typically 1.0 - ema_decay. Returns: An Op for updating the covariance Variable referenced by _cov and possibly updating mean activations. """ # The newly computed cov matrix is returned and assigned below to the # moving average. `new_cov` is required to compute mean activations. # Mean activations is given by last row and col of `new_cov. # Remove the last row and col from `new_cov`. 
new_cov = super(ConvInputSUAKroneckerFactor, self)._compute_total_new_cov() new_mu = new_cov[:-1, -1:] new_cov = new_cov[0:-1, 0:-1] if not ASSUME_ZERO_MEAN_ACTIVATIONS: new_cov = new_cov - tf.matmul(new_mu, new_mu, transpose_b=True) acc_mu_op = self._mu.add_to_average(new_mu, decay=ema_decay, weight=ema_weight) else: acc_mu_op = tf.no_op() if SUBTRACT_MEAN_CONTRIB_FROM_COV: new_cov = new_cov - tf.matmul(new_mu, new_mu, transpose_b=True) acc_cov_op = self._cov.add_to_average(new_cov, decay=ema_decay, weight=ema_weight) return tf.group(acc_cov_op, acc_mu_op) def _compute_new_cov(self, source, tower): assert source == 0 inputs = self._inputs[tower] # Reshape inputs to compute [in_channels, in_channels] shape cov. channel_inputs = tf.reshape(inputs, shape=(-1, self._in_channels)) # Append the bias dimension as we need this to calculate mean activations. channel_inputs = append_homog(channel_inputs) return compute_cov(channel_inputs) def register_matpower(self, exp, damping_func): """Registers a matrix power to be maintained and served on demand. This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method get_matpower. Args: exp: float. The exponent to use in the matrix power. damping_func: A function that computes a 0-D Tensor or a float which will be the damping value used. i.e. damping = damping_func(). """ if exp == 1.0: return if exp != -1: raise ValueError("ConvInputSUAKroneckerFactor supports only" "matrix inversion") damping_id = self._register_damping(damping_func) if (exp, damping_id) not in self._matpower_registrations: self._matpower_registrations.add((exp, damping_id)) def _compute_sm_rank_one_update_quants(self, exp, damping_id, damping_value): """Returns tensors to compute Fisher inv using Sherman-Morrison formula.""" cov_inv = self._matpower_by_exp_and_damping[(exp, damping_id)] cov_inv_mu = tf.matmul(cov_inv, self.mu) hatmu_t_cov_inv_hatmu = self._kw_kh * tf.squeeze( tf.matmul(self.mu, cov_inv_mu, transpose_a=True)) if self._has_bias: tildemu_t_cov_inv_tildemu = hatmu_t_cov_inv_hatmu + (1. / damping_value) return cov_inv_mu, (1. / (1. + tildemu_t_cov_inv_tildemu)) else: return cov_inv_mu, (1. / (1. + hatmu_t_cov_inv_hatmu)) def get_matpower(self, exp, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # self.cov (except when exp == 1). if exp == 1: return self._make_cov_linear_operator( damping=tf.cast(damping_func(), dtype=self._dtype)) elif exp == -1: damping_id = graph_func_to_id(damping_func) cov_inv = self._matpower_by_exp_and_damping[(exp, damping_id)] damping_value = self._damping_var_by_id[damping_id] # Replicates the in_channels * in_channels cov inverse matrix. # Note that in this function the replications are not done explicitly. # They are done using tf.linalg ops and hence they are computationally # efficient. quant_1 = tf.linalg.LinearOperatorKronecker([ tf.linalg.LinearOperatorFullMatrix( cov_inv, is_non_singular=True, is_self_adjoint=True, is_positive_definite=True, is_square=True), tf.linalg.LinearOperatorIdentity( num_rows=self._kw_kh, dtype=self._dtype) ]) # If a bias dimension needs to be appended then we need to expand # scaled_cov_inv_mu and assign `1` to the last dimension. Also # we need to append inverse of damping constant (1 * 1 matrix) to # to the replicated cov inverse matrix. if self._has_bias: bias_operator = tf.linalg.LinearOperatorFullMatrix( [[1. 
/ damping_value]], is_non_singular=True, is_self_adjoint=True, is_positive_definite=True, is_square=True) cov_inv_kron_identity_operator = tf.linalg.LinearOperatorBlockDiag( [quant_1, bias_operator]) if not ASSUME_ZERO_MEAN_ACTIVATIONS: cov_inv_mu = self._cov_inv_mu_by_damping_id[damping_id] scale = self._rank_one_update_scale_by_damping_id[damping_id] # Compute cov_inv_mu kron 1's vec. We tile the cov_inv_mu on the last # dim and then reshape. mean_update = ( tf.expand_dims( append_homog( tf.reshape(tf.tile(cov_inv_mu, [1, self._kw_kh]), (-1,)), homog_value=(1. / damping_value)), axis=1)) else: cov_inv_kron_identity_operator = quant_1 if not ASSUME_ZERO_MEAN_ACTIVATIONS: cov_inv_mu = self._cov_inv_mu_by_damping_id[damping_id] scale = self._rank_one_update_scale_by_damping_id[damping_id] # Compute cov_inv_mu kron 1's vec. We tile the cov_inv_mu on the last # dim and then reshape. mean_update = tf.reshape( tf.tile(cov_inv_mu, [1, self._kw_kh]), (-1, 1)) if ASSUME_ZERO_MEAN_ACTIVATIONS: return cov_inv_kron_identity_operator else: # To include the contribution from the mean activations we need to # low rank update op. Note the Sherman Morrison formula requires # negative of (mean_update * mean_update^T) / scale term to be added. # In order to achieve this using `LinearOperatorLowRankUpdate` set `v` # to negative of mean update vector multiplied by scale. return tf.linalg.LinearOperatorLowRankUpdate( cov_inv_kron_identity_operator, mean_update, v=-scale * mean_update, is_non_singular=True, is_self_adjoint=True, is_positive_definite=True, is_square=True) else: raise ValueError("ConvInputSUAKroneckerFactor only supports" "computing inverse of cov matrix.") def make_inverse_update_ops(self): """Creates and return update ops for registered computations.""" inverse_ops = [] for (exp, damping_id), matpower in self._matpower_by_exp_and_damping.items(): assert exp == -1 damping = tf.cast(self._damping_funcs_by_id[damping_id](), self._dtype) damping_assign_op = utils.smart_assign( self._damping_var_by_id[damping_id], damping) inverse_op = utils.smart_assign(matpower, utils.posdef_inv(self.cov, damping)) inverse_ops.append(damping_assign_op) if not ASSUME_ZERO_MEAN_ACTIVATIONS: with tf.control_dependencies([inverse_op]): (cov_inv_mu, rank_one_update_scale) = self._compute_sm_rank_one_update_quants( exp, damping_id, damping) inverse_ops.append( utils.smart_assign(self._cov_inv_mu_by_damping_id[damping_id], cov_inv_mu)) inverse_ops.append( utils.smart_assign( self._rank_one_update_scale_by_damping_id[damping_id], rank_one_update_scale)) else: inverse_ops.append(inverse_op) return inverse_ops def get_inverse(self, damping_func): # Just for backwards compatibility of some old code and tests return self.get_matpower(-1, damping_func) def instantiate_inv_variables(self): """Makes the internal "inverse" variable(s).""" for (exp, damping_id) in self._matpower_registrations: if exp != -1.: raise ValueError("ConvInputSUAKroneckerFactor only supports inverse" "computation") exp_string = scalar_or_tensor_to_string(exp) damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) with tf.variable_scope(self._var_scope): matpower = tf.get_variable( "matpower_exp{}_damp{}".format(exp_string, damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) assert (exp, damping_id) not in self._matpower_by_exp_and_damping self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower 
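    # In addition to the matrix-power variable created above we keep, per
    # damping function, a scalar copy of the damping value and, when mean
    # activations are modelled (ASSUME_ZERO_MEAN_ACTIVATIONS == False), the
    # Sherman-Morrison quantities cov^{-1} mu and the rank-one update scale.
    # get_matpower() reads these variables to assemble the damped inverse as a
    # linear operator without recomputing them.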
self._damping_var_by_id[damping_id] = tf.get_variable( "damping_var_{}_{}".format(exp_string, damping_string), initializer=tf.zeros_initializer(), shape=(), trainable=False, dtype=self._dtype, use_resource=True) if not ASSUME_ZERO_MEAN_ACTIVATIONS: self._cov_inv_mu_by_damping_id[damping_id] = tf.get_variable( "cov_inv_mu_{}_{}".format(exp_string, damping_string), initializer=tf.zeros_initializer(), shape=(self._in_channels, 1), trainable=False, dtype=self._dtype, use_resource=True) self._rank_one_update_scale_by_damping_id[damping_id] = tf.get_variable( "rank_one_update_scale_{}_{}".format(exp_string, damping_string), initializer=tf.zeros_initializer(), shape=(), trainable=False, dtype=self._dtype, use_resource=True) def _make_cov_linear_operator(self, damping=None): """Returns cov as a linear operator. Args: damping: Damping value tensor. If `damping` is not None then returns damped covariance matrix. Returns: tf.linalg.LinearOperator instance. """ if damping is not None: cov = self.cov + damping * tf.eye(self._cov_shape[0], dtype=self._dtype) else: cov = self.cov cov_operator = tf.linalg.LinearOperatorKronecker([ tf.linalg.LinearOperatorFullMatrix( cov, is_self_adjoint=True, is_square=True), tf.linalg.LinearOperatorIdentity( num_rows=self._kw_kh, dtype=self._dtype) ]) if self._has_bias: bias_value = damping if damping is not None else 0. bias_operator = tf.linalg.LinearOperatorFullMatrix([[bias_value]], is_self_adjoint=True, is_square=True) cov_operator = tf.linalg.LinearOperatorBlockDiag( [cov_operator, bias_operator]) if ASSUME_ZERO_MEAN_ACTIVATIONS: return cov_operator else: # self.mu kron 1's vec is computed below by tiling mu. hatmu = tf.tile(self.mu, [1, self._kw_kh]) if self._has_bias: tildemu = append_homog(tf.reshape(hatmu, (-1,))) mean_update = tf.expand_dims(tildemu, axis=1) else: mean_update = tf.reshape(hatmu, (-1, 1)) return tf.linalg.LinearOperatorLowRankUpdate( cov_operator, mean_update, is_self_adjoint=True, is_square=True) def get_cov_as_linear_operator(self): return self._make_cov_linear_operator() def get_cholesky(self, damping_func): raise NotImplementedError("ConvInputSUAKroneckerFactor does not support" "cholesky factorization") def get_cholesky_inverse(self, damping_func): raise NotImplementedError("ConvInputSUAKroneckerFactor does not support" "cholesky inverse computation") def register_cholesky(self): raise NotImplementedError("ConvInputSUAKroneckerFactor does not support" "cholesky factorization") def register_cholesky_inverse(self): raise NotImplementedError("ConvInputSUAKroneckerFactor does not support" "cholesky inverse computation") def _get_data_device(self, tower): return self._inputs[tower].device class ConvOutputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the output side of a convolutional layer. Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over all examples and locations. Equivalent to Gamma in https://arxiv.org/abs/1602.01407 for details. See Section 3.1 Estimating the factors. """ def __init__(self, outputs_grads, data_format=None): """Initializes ConvOutputKroneckerFactor. Args: outputs_grads: List of list of Tensors. Each Tensor is of shape [batch_size, ..spatial_input_size.., out_channels]. First list index is source, the second is tower. data_format: None or str. Format of outputs_grads. Raises: ValueError: If channels are not final dimension. 
""" if not utils.is_data_format_channel_last(data_format): raise ValueError("Channel must be last.") self._out_channels = outputs_grads[0][0].shape.as_list()[-1] self._outputs_grads = outputs_grads super(ConvOutputKroneckerFactor, self).__init__() @property def _var_scope(self): return "ff_convoutkron_" + scope_string_from_params( nest.flatten(self._outputs_grads)) @property def _cov_shape(self): size = self._out_channels return [size, size] @property def _num_sources(self): return len(self._outputs_grads) @property def _num_towers(self): return len(self._outputs_grads[0]) @property def _dtype(self): return self._outputs_grads[0][0].dtype def _partial_batch_size(self, source=0, tower=0): return utils.get_shape(self._outputs_grads[source][tower])[0] def _compute_new_cov(self, source, tower): outputs_grad = self._outputs_grads[source][tower] # reshaped_tensor below is the matrix DS_l defined in the KFC paper # (tilde omitted over S for clarity). It has shape M|T| x I, where # M = minibatch size, |T| = number of spatial locations, and # I = number of output maps for convolutional layer l. reshaped_tensor = tf.reshape(outputs_grad, [-1, self._out_channels]) # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l # as defined in the paper, with shape I x I. # (Tilde omitted over S for clarity.) return compute_cov(reshaped_tensor) def _get_data_device(self, tower): return self._outputs_grads[0][tower].device class ConvOutputMultiKF(ConvOutputKroneckerFactor): def __init__(self, outputs_grads, num_uses, data_format=None): super(ConvOutputMultiKF, self).__init__(outputs_grads, data_format=data_format) self._num_uses = num_uses def _partial_batch_size(self, source=0, tower=0): # Note that some internal comptutations of "batch_size" done in the parent # class won't actually be the proper batch size. Instead, they will be # just "the thing to normalize the statistics by", essentially. This is okay # as we don't mix the two things up. return (super(ConvOutputMultiKF, self)._partial_batch_size(source=source, tower=tower) // self._num_uses) class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): """Kronecker factor for a fully connected layer used multiple times.""" def __init__(self, tensors, num_uses=None, has_bias=False): """Constructs a new `FullyConnectedMultiKF`. Args: tensors: List of list of Tensors of shape, each of shape [num_uses * batch_size, n], and is a reshape version of a Tensor of shape [num_uses, batch_size, n]. Each of these tensors is usually a layer's inputs or its output's gradients. The first list index is sources, the second is towers. num_uses: int. The number of time-steps / uses. has_bias: bool. If True, '1' is appended to each row. 
""" self._num_uses = num_uses self._cov_dt1 = None self._acc_cov_dt1 = None self._make_cov_dt1 = False self._option1quants_by_damping = OrderedDict() self._option2quants_by_damping = OrderedDict() self._option1quants_registrations = set() self._option2quants_registrations = set() super(FullyConnectedMultiKF, self).__init__(tensors=tensors, has_bias=has_bias) @property def _num_timesteps(self): return self._num_uses def _partial_batch_size(self, source=0, tower=0): shape = utils.get_shape(self._tensors[source][tower]) if len(shape) == 2: # the folded case return shape[0] // self._num_timesteps elif len(shape) == 3: return shape[1] # batch is the second dim @property def _var_scope(self): return "ff_fc_multi_" + scope_string_from_params( tuple(nest.flatten(self._tensors)) + (self._num_timesteps, self._has_bias,)) def get_inv_vars(self): inv_vars = super(FullyConnectedMultiKF, self).get_inv_vars() inv_vars.extend(self._option1quants_by_damping.values()) inv_vars.extend(self._option2quants_by_damping.values()) return inv_vars def make_covariance_update_op(self, ema_decay, ema_weight): op = super(FullyConnectedMultiKF, self).make_covariance_update_op( ema_decay, ema_weight) if self._cov_dt1 is not None: new_cov_dt1_contribs = [] for source in range(self._num_sources): for tower in range(self._num_towers): with maybe_place_on_device(self._get_data_device(tower)): new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source, tower)) new_cov_dt1 = (tf.add_n(new_cov_dt1_contribs) / float(self._num_towers)) # See comments in FisherFactor.make_covariance_update_op() for details. new_cov_dt1 = utils.all_average(new_cov_dt1) op2 = self._cov_dt1.add_to_average(new_cov_dt1, decay=ema_decay, weight=ema_weight) # TODO(b/69112164): # It's important that _cov and _cov_dt1 remain consistent with each # other while the inverse ops are happening. How can we ensure this? # We will need to add explicit synchronization for this to # work with asynchronous training. op = tf.group(op, op2) return op def _compute_new_cov(self, source, tower): tensor = self._tensors[source][tower] if len(tensor.shape) == 3: tensor = tf.reshape(tensor, [-1, tensor.shape[2]]) if self._has_bias: tensor = append_homog(tensor) return compute_cov(tensor) def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring tensor = self._tensors[source][tower] if len(tensor.shape) == 3: tensor = tf.reshape(tensor, [-1, tensor.shape[2]]) if self._has_bias: # This appending is technically done twice (the other time is for # _compute_new_cov()) tensor = append_homog(tensor) total_len = utils.get_shape(tensor)[0] batch_size = total_len // self._num_timesteps tensor_present = tensor[:-batch_size, :] tensor_future = tensor[batch_size:, :] # We specify a normalizer for this computation to ensure a PSD Fisher # block estimate. This is equivalent to padding with zeros, as was done # in Section B.2 of the appendix. 
return compute_cov( tensor_future, tensor_right=tensor_present, normalizer=total_len) @property def _cov_shape(self): shape = self._tensors[0][0].shape if len(shape) == 2: size = shape[1] + self._has_bias elif len(shape) == 3: size = shape[2] + self._has_bias return [size, size] def _get_data_device(self, tower): return self._tensors[0][tower].device @property def _vec_shape(self): size = self._tensors[0][0].shape[1] + self._has_bias return [size] def get_option1quants(self, damping_func): damping_id = graph_func_to_id(damping_func) return self._option1quants_by_damping[damping_id] def get_option2quants(self, damping_func): damping_id = graph_func_to_id(damping_func) return self._option2quants_by_damping[damping_id] @property def cov_dt1(self): assert self._cov_dt1 is not None return self._cov_dt1.value def get_cov_vars(self): cov_vars = super(FullyConnectedMultiKF, self).get_cov_vars() if self._make_cov_dt1: cov_vars += [self.cov_dt1] return cov_vars def register_cov_dt1(self): self._make_cov_dt1 = True def instantiate_cov_variables(self): super(FullyConnectedMultiKF, self).instantiate_cov_variables() assert self._cov_dt1 is None if self._make_cov_dt1: with tf.variable_scope(self._var_scope): self._cov_dt1 = utils.MovingAverageVariable( name="cov_dt1", shape=self._cov_shape, dtype=self._dtype, initializer=tf.zeros_initializer(), normalize_value=ZERO_DEBIAS) def register_option1quants(self, damping_func): damping_id = self._register_damping(damping_func) if damping_id not in self._option1quants_registrations: self._option1quants_registrations.add(damping_id) def register_option2quants(self, damping_func): damping_id = self._register_damping(damping_func) if damping_id not in self._option2quants_registrations: self._option2quants_registrations.add(damping_id) def instantiate_inv_variables(self): super(FullyConnectedMultiKF, self).instantiate_inv_variables() for damping_id in self._option1quants_registrations: damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) # It's questionable as to whether we should initialize with stuff like # this at all. Ideally these values should never be used until they are # updated at least once. with tf.variable_scope(self._var_scope): Lmat = tf.get_variable( # pylint: disable=invalid-name "Lmat_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) psi = tf.get_variable( "psi_damp{}".format(damping_string), initializer=tf.ones_initializer(), shape=self._vec_shape, trainable=False, dtype=self._dtype, use_resource=True) assert damping_id not in self._option1quants_by_damping self._option1quants_by_damping[damping_id] = (Lmat, psi) for damping_id in self._option2quants_registrations: damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) # It's questionable as to whether we should initialize with stuff like # this at all. Ideally these values should never be used until they are # updated at least once. 
with tf.variable_scope(self._var_scope): Pmat = tf.get_variable( # pylint: disable=invalid-name "Lmat_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) Kmat = tf.get_variable( # pylint: disable=invalid-name "Kmat_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) mu = tf.get_variable( "mu_damp{}".format(damping_string), initializer=tf.ones_initializer(), shape=self._vec_shape, trainable=False, dtype=self._dtype, use_resource=True) assert damping_id not in self._option2quants_by_damping self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu) def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" # TODO(b/69918258): Add correctness tests for this method. # pylint: disable=invalid-name ops = [] if (len(self._option1quants_by_damping) + len(self._option2quants_by_damping)): # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from # the pseudo-code in the original paper. Because the computations for # the A and G case are essentially the same they can both be performed by # the same class (this one). C1 = self.cov_dt1 # Get the eigendecomposition of C0 (= self.cov) eigen_e, eigen_V = self.get_eigendecomp() # TODO(b/69678661): Note, there is an implicit assumption here that C1 # and C0 (as represented here by its eigen-decomp) are consistent. This # could fail to be the case if self._cov and self._cov_dt1 are not updated # consistently, or are somehow read between or during the cov updates. # Can this possibly happen? Is there a way to prevent it? for damping_id, (Lmat_var, psi_var) in self._option1quants_by_damping.items(): damping = self._damping_funcs_by_id[damping_id]() damping = tf.cast(damping, self._dtype) invsqrtC0 = tf.matmul( eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) # Might need to enforce symmetry lost due to numerical issues. invsqrtC0 = (invsqrtC0 + tf.transpose(invsqrtC0)) / 2.0 # The following line imposes the symmetry assumed by "Option 1" on C1. # Strangely the code can work okay with this line commented out, # depending on how psd_eig is defined. I'm not sure why. C1 = (C1 + tf.transpose(C1)) / 2.0 # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) hPsi = tf.matmul(tf.matmul(invsqrtC0, C1), invsqrtC0) # Compute the decomposition U*diag(psi)*U^T = hPsi psi, U = utils.posdef_eig(hPsi) # L = C0^(-1/2) * U Lmat = tf.matmul(invsqrtC0, U) ops.append(utils.smart_assign(Lmat_var, Lmat)) ops.append(utils.smart_assign(psi_var, psi)) for damping_id, (Pmat_var, Kmat_var, mu_var) in self._option2quants_by_damping.items(): damping = self._damping_funcs_by_id[damping_id]() damping = tf.cast(damping, self._dtype) # compute C0^(-1/2) invsqrtC0 = tf.matmul( eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) # Might need to enforce symmetry lost due to numerical issues. invsqrtC0 = (invsqrtC0 + tf.transpose(invsqrtC0)) / 2.0 # Compute the product C0^(-1/2) * C1 invsqrtC0C1 = tf.matmul(invsqrtC0, C1) # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) hPsi = tf.matmul(invsqrtC0C1, invsqrtC0) # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi # Note that we using the notation mu instead of "m" for the eigenvalues. 
# Instead of computing the product hPsi^T * hPsi and then doing an # eigen-decomposition of this we just compute the SVD of hPsi and then # square the singular values to get the eigenvalues. For a justification # of this approach, see: # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition sqrtmu, _, E = tf.svd(hPsi) mu = tf.square(sqrtmu) # Mathematically, the eigenvalues should not exceed 1.0, but # due to numerical issues, or possible issues with inconsistent # values of C1 and (the eigen-decomposition of) C0 they might. So # we enforce this condition. mu = tf.minimum(mu, 1.0) # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1) Pmat = tf.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True) # K = C_0^(-1/2) * E Kmat = tf.matmul(invsqrtC0, E) ops.append(utils.smart_assign(Pmat_var, Pmat)) ops.append(utils.smart_assign(Kmat_var, Kmat)) ops.append(utils.smart_assign(mu_var, mu)) ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops() return [tf.group(*ops)] # pylint: enable=invalid-name ================================================ FILE: kfac/python/ops/kfac_utils/__init__.py ================================================ ================================================ FILE: kfac/python/ops/kfac_utils/async_inv_cov_update_kfac_opt.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Implementation of KFAC which runs cov and inv ops asynchronously.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import threading # Dependency imports import tensorflow.compat.v1 as tf from kfac.python.ops import optimizer _MAX_NUM_COV_INV_UPDATE_THREADS = 10 class AsyncInvCovUpdateKfacOpt(optimizer.KfacOptimizer): """Provides functionality to run cov and inv ops asynchronously. The update ops are placed on devices in a round robin manner. These ops are run asynchronously in the sense that the training op and the cov and inv matrix computations are run independently of each other. The cov and inv ops are run in the background by threads. Example usage: opt = AsyncInvCovUpdateKfacOpt(cov_devices=["/gpu:0"], inv_devices=["/gpu:1"]) train_op = opt.minimize(loss) with tf.Session() as sess: opt.run_cov_inv_ops(sess) for _ in range(100): sess.run([train_op]) opt.stop_cov_inv_ops(sess) """ def __init__(self, cov_devices, inv_devices, num_cov_inv_update_threads=None, **kwargs): """Initializes AsyncInvCovUpdateKfacOpt. See the docstring for `KfacOptimizer` class (in optimizer.py) for the complete list of arguments (there are many!). Args: cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance computations will be placed on these devices in a round-robin fashion. Can be None, which means that no devices are specified. inv_devices: Iterable of device strings (e.g. '/gpu:0').
Inversion computations will be placed on these devices in a round-robin fashion. Can be None, which means that no devices are specified. num_cov_inv_update_threads: `int`, Number of parallel computations of inverse and covariance ops. If a value is not passed then the number of threads will be set to half of length of number of ops to run asynchronously (Capped at `_MAX_NUM_COV_INV_UPDATE_THREADS`). (Default: None) **kwargs: Arguments to `KfacOptimizer` class. """ self.next_op = None self._coord = None self._num_cov_inv_update_threads = num_cov_inv_update_threads self._threads = None super(AsyncInvCovUpdateKfacOpt, self).__init__( placement_strategy="round_robin", **kwargs) def _make_ops(self, update_thunks): return [thunk() for thunk in update_thunks] def apply_gradients(self, grads_and_vars, global_step=None, name=None): cov_update_thunks, inv_update_thunks = self.make_vars_and_create_op_thunks() apply_grads = super(AsyncInvCovUpdateKfacOpt, self).apply_gradients( grads_and_vars=grads_and_vars, global_step=global_step, name=name) self._set_up_op_name_queue( self._make_ops(cov_update_thunks + inv_update_thunks)) return apply_grads def run_cov_inv_ops(self, sess): """Starts threads to run covariance and inverse ops.""" self._coord = tf.train.Coordinator() self._threads = [ threading.Thread(target=self._run_ops, args=( (sess,) )) for _ in range(self._num_cov_inv_update_threads) ] for t in self._threads: t.start() def _run_ops(self, sess): """Runs the covariance and inverse ops. Each thread gets the next op name to run from the shared dataset that is created in `_set_up_op_name_queue` method. The opname is mapped to the op which is run in thread context. Args: sess: `tf.Session` instance. """ while not self._coord.should_stop(): next_op_name = sess.run(self._next_op_name).decode("ascii") next_op = self._ops_by_name[next_op_name] sess.run(next_op) def stop_cov_inv_ops(self, sess): """Signals coordinator to stop and waits for threads to terminate.""" self._coord.request_stop() self._coord.join(self._threads) def _set_up_op_name_queue(self, ops_to_run): """Sets up a queue of op names. Convert the names of ops to run to tensors and creates a dataset of names. The op name tensors in the Dataset are repeated indefinitely. Running `self._next_op_name` returns the name of the next op to execute. Args: ops_to_run: `List` of ops to run asynchronously. """ self._num_cov_inv_update_threads = self._num_cov_inv_update_threads or max( int(len(ops_to_run) / 2), _MAX_NUM_COV_INV_UPDATE_THREADS) self._ops_by_name = {op.name: op for op in ops_to_run} op_names = tf.convert_to_tensor(list(sorted(op.name for op in ops_to_run))) op_names_dataset = tf.data.Dataset.from_tensor_slices(op_names).repeat() self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next() ================================================ FILE: kfac/python/ops/kfac_utils/data_reader.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Reads variable size batches of data from a data set and stores read data. `VariableBatchReader` reads variable size data from a dataset. `CachedDataReader` on top of `VariableBatchReader` adds functionality to store the read batch for use in the next session.run() call. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import tensorflow.compat.v1 as tf def _slice_data(stored_data, size): return [data[:size] for data in stored_data] class VariableBatchReader(object): """Read data of varying batch sizes from a data set.""" def __init__(self, dataset, max_batch_size): """Initializes class. Args: dataset: List of Tensors representing the dataset, shuffled, repeated, and batched into mini-batches of size at least `max_batch_size`. In other words it should be reshuffled at each session.run call. This can be done with the tf.data package using the construction demonstrated in load_mnist() function in examples/autoencoder_auto_damping.py. max_batch_size: `int`. Maximum batch size of the data that can be retrieved from the data set. """ self._dataset = dataset self._max_batch_size = max_batch_size def __call__(self, batch_size): """Reads `batch_size` data. Args: batch_size: Tensor of type `int32`, batch size of the data to be retrieved from the dataset. `batch_size` should be less than or equal to `max_batch_size`. Returns: Read data, An iterable of tensors with batch size equal to `batch_size`. """ check_size = tf.assert_less_equal( batch_size, tf.convert_to_tensor(self._max_batch_size, dtype=tf.int32), message='Data set read failure, Batch size greater than max allowed.' ) with tf.control_dependencies([check_size]): return _slice_data(self._dataset, batch_size) class CachedDataReader(VariableBatchReader): """Provides functionality to store variable batch size data.""" def __init__(self, dataset, max_batch_size): """Initializes class and creates variables for storing previous batch. Args: dataset: List of Tensors representing the dataset, shuffled, repeated, and batched into mini-batches of size at least `max_batch_size`. In other words it should be reshuffled at each session.run call. This can be done with the tf.data package using the construction demonstrated in load_mnist() function in examples/autoencoder_auto_damping.py. max_batch_size: `int`. Maximum batch size of the data that can be retrieved from the data set. """ super(CachedDataReader, self).__init__(dataset, max_batch_size) with tf.variable_scope('cached_data_reader'): self._cached_batch_storage = [ tf.get_variable( name='{}{}'.format('cached_batch_storage_', i), shape=[max_batch_size]+ var.shape.as_list()[1:], dtype=var.dtype, trainable=False, use_resource=True) for i, var in enumerate(self._dataset) ] self._cached_batch_size = tf.get_variable( name='cached_batch_size', shape=(), dtype=tf.int32, trainable=False, use_resource=True) self._cached_batch = _slice_data(self._cached_batch_storage, self._cached_batch_size) def __call__(self, batch_size): """Reads `batch_size` data and stores the read batch. Args: batch_size: Tensor of type `int32`, batch size of the data to be retrieved from the dataset. `batch_size` should be less than or equal to `max_batch_size`. Returns: Read data, An iterable of tensors with batch size equal to `batch_size`. 
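    Example (illustrative only; the dataset tensors and sizes below are
    hypothetical, and in practice the dataset should be reshuffled on each
    session.run call as described in the constructor docstring):

      images = tf.random.normal([512, 784])
      labels = tf.random.uniform([512], maxval=10, dtype=tf.int32)
      reader = CachedDataReader([images, labels], max_batch_size=512)
      cur_images, cur_labels = reader(tf.constant(128, dtype=tf.int32))
      # On the next session.run, reader.cached_batch returns the examples
      # that were read (and cached) by the call above.
      prev_images, prev_labels = reader.cached_batch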
""" sliced_data = super(CachedDataReader, self).__call__(batch_size) # We need to make sure we read the cached batch before we update it! with tf.control_dependencies(self._cached_batch): batch_size_assign_op = self._cached_batch_size.assign(batch_size) data_assign_ops = [ prev[:batch_size].assign(cur) # yes, this actually works for prev, cur in zip(self._cached_batch_storage, sliced_data) ] with tf.control_dependencies(data_assign_ops + [batch_size_assign_op]): return [tf.identity(sdata) for sdata in sliced_data] @property def cached_batch(self): return self._cached_batch ================================================ FILE: kfac/python/ops/kfac_utils/data_reader_alt.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Reads variable size batches of data from a data set and stores read data. `VariableBatchReader` reads variable size data from a dataset. `CachedDataReader` on top of `VariableBatchReader` adds functionality to store the read batch for use in the next session.run() call. This file is similar to data_reader.py but uses an alternative implementation that requires the whole dataset to be passed in. This will often be faster than using the original implementation with a very large max_batch_size. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import tensorflow.compat.v1 as tf def _extract_data(tensor_list, indices): return [tf.gather(tensor, indices, axis=0) for tensor in tensor_list] class VariableBatchReader(object): """Read data of varying batch sizes from a data set.""" def __init__(self, dataset, num_examples): """Initializes class. Args: dataset: List of Tensors. These must remain constant across session.run calls, unlike the version of VariableBatchReader in data_reader.py. num_examples: The number of examples in the data set (i.e. dimension 0 of the elements of `dataset`). """ self._dataset = dataset self._num_examples = num_examples self._indices = None def __call__(self, batch_size): """Reads `batch_size` data. Args: batch_size: Tensor of type `int32`. Batch size of the data to be retrieved from the dataset. `batch_size` should be less than or equal to the number of examples in the dataset. Returns: Read data, a list of Tensors with batch size equal to `batch_size`. """ check_size = tf.assert_less_equal( batch_size, tf.convert_to_tensor(self._num_examples, dtype=tf.int32), message='Data set read failure, batch_size > num_examples.' 
) with tf.control_dependencies([check_size]): self._indices = tf.random.shuffle( tf.range(self._num_examples, dtype=tf.int32)) return _extract_data(self._dataset, self._indices[:batch_size]) class CachedDataReader(VariableBatchReader): """Provides functionality to store variable batch size data.""" def __init__(self, dataset, num_examples): """Initializes class and creates variables for storing previous batch. Args: dataset: List of Tensors. These must remain constant across session.run calls, unlike the version of VariableBatchReader in data_reader.py. num_examples: The number of examples in the data set (i.e. dimension 0 of the elements of `dataset`). """ super(CachedDataReader, self).__init__(dataset, num_examples) self._cached_batch_indices = tf.get_variable( name='cached_batch_indices', shape=[self._num_examples], dtype=tf.int32, trainable=False, use_resource=True) self._cached_batch_size = tf.get_variable( name='cached_batch_size', shape=(), dtype=tf.int32, trainable=False, use_resource=True) self._cached_batch = _extract_data( self._dataset, self._cached_batch_indices[:self._cached_batch_size]) def __call__(self, batch_size): """Reads `batch_size` data and stores the read batch. Args: batch_size: Tensor of type `int32`, batch size of the data to be retrieved from the dataset. `batch_size` should be less than or equal to `max_batch_size`. Returns: Read data, An iterable of tensors with batch size equal to `batch_size`. """ tensor_list = super(CachedDataReader, self).__call__(batch_size) with tf.control_dependencies(self._cached_batch): indices_assign_op = self._cached_batch_indices.assign(self._indices) batch_size_assign_op = tf.assign(self._cached_batch_size, batch_size) with tf.control_dependencies([indices_assign_op, batch_size_assign_op]): return [tf.identity(tensor) for tensor in tensor_list] @property def cached_batch(self): return self._cached_batch ================================================ FILE: kfac/python/ops/kfac_utils/periodic_inv_cov_update_kfac_opt.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Implementation of KFAC which runs covariance and inverse ops periodically. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports from absl import logging import tensorflow.compat.v1 as tf from kfac.python.ops import optimizer from kfac.python.ops import utils class PeriodicInvCovUpdateKfacOpt(optimizer.KfacOptimizer): """Provides functionality to run covariance and inverse ops periodically. Creates KFAC optimizer with a `placement strategy`. Also runs the covariance and inverse ops periodically. The base class does not provide a mechanism to automatically construct and run the covariance and inverse ops, they must be created and run manually using make_vars_and_create_op_thunks or create_ops_and_vars_thunks. 
This class provides functionality to create these ops and run them periodically whenever the optimizer.minimize() op is run. The inverse ops are run every `invert_every` iterations and the covariance statistics are updated every `cov_update_every` iterations. Ideally, set `invert_every` to a multiple of `cov_update_every` so that the inverses are computed after the covariance is updated. The larger this multiple, the greater the delay in using the computed covariance estimates in the KFAC update step. Computing the statistics and inverses periodically also saves on computation cost, and a "reasonable" setting often shows no performance degradation compared to computing these quantities every iteration. """ def __init__(self, invert_every=10, cov_update_every=1, num_burnin_steps=0, **kwargs): """Initializes a PeriodicInvCovUpdateKfacOpt object. See the docstring for `KfacOptimizer` class (in optimizer.py) for the complete list of arguments (there are many!). Please keep in mind that while the K-FAC code loosely conforms to TensorFlow's Optimizer API, it can't be used naively as a "drop in replacement" for basic classes like MomentumOptimizer. Using it properly with SyncReplicasOptimizer, for example, requires special care. See the various examples in the "examples" directory for a guide about how to use K-FAC in various contexts and various systems, like TF-Estimator. See in particular the convnet example. google/examples also contains an example using TPUEstimator. Note that not all use cases will work with PeriodicInvCovUpdateKfacOpt. Sometimes you will have to use the base KfacOptimizer which provides more fine-grained control over ops. Other times you might want to use one of the other subclassed optimizers like AsyncInvCovUpdateKfacOpt. Args: invert_every: int. The inversion ops are run once every `invert_every` executions of the training op. (Default: 10) cov_update_every: int. The 'covariance update ops' are run once every `cov_update_every` executions of the training op. (Default: 1) num_burnin_steps: int. For the first `num_burnin_steps` steps the optimizer will only perform cov updates. Note: this doesn't work with CrossShardOptimizer, since the custom minimize method implementation will be ignored, or with MirroredStrategy, due to behavior of conditional parameter updates with multiple replicas. (Default: 0) **kwargs: Arguments to `KfacOptimizer` class. Raises: ValueError: if num_burnin_steps is non-zero and MirroredStrategy is being used.
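    Example (a minimal sketch; the hyperparameter values and the externally
    constructed `layer_collection` and `loss` are illustrative, not
    recommendations):

      optimizer = PeriodicInvCovUpdateKfacOpt(
          invert_every=10,
          cov_update_every=1,
          learning_rate=0.0001,
          cov_ema_decay=0.95,
          damping=0.001,
          momentum=0.9,
          layer_collection=layer_collection)
      train_op = optimizer.minimize(
          loss, global_step=tf.train.get_or_create_global_step())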
""" if "cov_ema_decay" in kwargs: kwargs["cov_ema_decay"] = kwargs["cov_ema_decay"]**cov_update_every super(PeriodicInvCovUpdateKfacOpt, self).__init__(**kwargs) self._invert_every = invert_every self._cov_update_every = cov_update_every self._num_burnin_steps = num_burnin_steps self._made_vars_already = False if self._adapt_damping: if self._damping_adaptation_interval % self._invert_every != 0: logging.warning("WARNING: damping_adaptation_interval isn't divisible " "by invert_every.") if (tf.distribute.has_strategy() and tf.distribute.get_replica_context()): strategy = tf.distribute.get_strategy() if (isinstance(strategy, tf.distribute.MirroredStrategy) and self._num_burnin_steps > 0): raise ValueError("num_burnin_steps must be 0 with MirroredStrategy.") with tf.variable_scope(self.get_name()): self._burnin_counter = tf.get_variable( "burnin_counter", dtype=tf.int64, shape=(), trainable=False, initializer=tf.zeros_initializer, use_resource=True, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA) def minimize(self, loss, global_step=None, var_list=None, gate_gradients=tf.train.Optimizer.GATE_OP, aggregation_method=None, colocate_gradients_with_ops=True, name=None, grad_loss=None, **kwargs): # This method has the same general arguments as the minimize methods in # standard optimizers do. if not self._made_vars_already: cov_update_thunks, _ = self.make_vars_and_create_op_thunks() else: (_, cov_update_thunks, _, _) = self.create_ops_and_vars_thunks() self._made_vars_already = True def update_cov_and_burnin_counter(): cov_update = tf.group(*(thunk(should_decay=False) for thunk in cov_update_thunks)) burnin_counter_update = self._burnin_counter.assign( self._burnin_counter + 1) return tf.group(cov_update, burnin_counter_update) def super_minimize(): return super(PeriodicInvCovUpdateKfacOpt, self).minimize( loss, global_step=global_step, var_list=var_list, gate_gradients=gate_gradients, aggregation_method=aggregation_method, colocate_gradients_with_ops=colocate_gradients_with_ops, name=name, grad_loss=grad_loss, **kwargs) if self._num_burnin_steps == 0: return super_minimize() else: return tf.cond(self._burnin_counter < self._num_burnin_steps, update_cov_and_burnin_counter, super_minimize) def apply_gradients(self, grads_and_vars, global_step=None, name=None): with tf.control_dependencies([self.kfac_update_ops()]): return super(PeriodicInvCovUpdateKfacOpt, self).apply_gradients( grads_and_vars=grads_and_vars, global_step=global_step, name=name) def kfac_update_ops(self): """Sets up the KFAC factor update ops. Returns: An op that when run will run the update ops at their update frequencies. """ # This if-statement is a trick/hack to maintain compatibility with # CrossShardOptimizer or other optimizers that might not call our # custom minimize() method (that would otherwise always make the variables). if not self._made_vars_already: (cov_update_thunks, inv_update_thunks) = self.make_vars_and_create_op_thunks() logging.warning("It looks like apply_gradients() was called before " "minimze() was called. This is not recommended, and you " "should avoid using optimizer wrappers like " "CrossShardOptimizer with K-FAC that try to bypass the " "minimize() method. The burn-in feature won't work when " "the class is used this way, for example. 
And K-FAC does " "its own cross-relica syncronization.") else: (_, cov_update_thunks, _, inv_update_thunks) = self.create_ops_and_vars_thunks() should_do_cov_updates = tf.equal(tf.mod(self.counter, self._cov_update_every), 0) maybe_cov_updates = utils.smart_cond( should_do_cov_updates, lambda: tf.group(*(thunk() for thunk in cov_update_thunks)), tf.no_op) maybe_pre_update_adapt_damping = self.maybe_pre_update_adapt_damping() with tf.control_dependencies([maybe_cov_updates, maybe_pre_update_adapt_damping]): should_do_inv_updates = tf.equal(tf.mod(self.counter, self._invert_every), 0) maybe_inv_updates = utils.smart_cond( should_do_inv_updates, lambda: tf.group(*(thunk() for thunk in inv_update_thunks)), tf.no_op) return maybe_inv_updates ================================================ FILE: kfac/python/ops/layer_collection.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Registry for layers and their parameters/variables. This represents the collection of all layers in the approximate Fisher information matrix to which a particular FisherBlock may belong. That is, we might have several layer collections for one TF graph (if we have multiple K-FAC optimizers being used, for example.) The model and loss function are registered using the register_XXX() methods. A subset of the layer types can be handled with the auto_register_layers() method. Note that the data formats in the docstrings for the register_XXX() methods must be strictly adhered to. So for example, if a method asks for a Tensor of shape [batch_size, ...], then the first dimension must be the batch size and nothing else. And the tensors must contain actual data, not a mixture of real and fake data / zeros generated by mini-batch padding, for example. (Padding is only fine if it's treated as regular data by both your model and loss function. e.g. adding "blank tokens" at the end of a sequence which the model is still expected to predict.) If a method asks for the parameters of a layer then they must be the actual variable object(s) for said parameters, not a tensor formed by reshaping, re-casting, or tranposing its value. If the internal data format used by your model isn't natively supported by this system, you shouldn't try to crow-bar the arguments of the registration methods until they seem to fit. Although the K-FAC code tries to protect against some common mistakes, it may often seem to run fine with incorrect registrations, generating no exceptions or errors. But this will almost certainly lead to (potentially severe) underperformance of the method. If you have model code that doesn't represent tensors in the format expected by K-FAC, one thing you can try is introducing transformations that perform the conversion back and forth. 
But make sure the format that you convert to is actually valid according to the strict specifications of the registration function docstrings (e.g. that batch_size really is the mini-batch size, etc). So if "x" is some data needed in the registration function that isn't of the correct format, you can try something like the following: x_transformed = transform(x) lc.register_XXX(x_transformed) x = untransform(x_transformed) ...use x in rest of model... Note that without "x = untransform(x_transformed)" this often won't work since x_transformed won't be part of the model's forward graph, which is something K-FAC needs (especially for the "output" arguments of layers). """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from collections import defaultdict from collections import OrderedDict from contextlib import contextmanager from functools import partial import math # Dependency imports import six import tensorflow.compat.v1 as tf from tensorflow.python.util import nest from kfac.python.ops import fisher_blocks as fb from kfac.python.ops import loss_functions as lf from kfac.python.ops import utils from kfac.python.ops.tensormatch import graph_search # Names for various approximations that can be requested for Fisher blocks. APPROX_KRONECKER_NAME = "kron" APPROX_KRONECKER_IN_DIAG_NAME = "kron_in_diag" APPROX_KRONECKER_OUT_DIAG_NAME = "kron_out_diag" APPROX_KRONECKER_BOTH_DIAG_NAME = "kron_both_diag" APPROX_DIAGONAL_NAME = "diagonal" APPROX_FULL_NAME = "full" APPROX_KRONECKER_INDEP_NAME = "kron_indep" APPROX_KRONECKER_INDEP_IN_DIAG_NAME = "kron_indep_in_diag" APPROX_KRONECKER_INDEP_OUT_DIAG_NAME = "kron_indep_out_diag" APPROX_KRONECKER_INDEP_BOTH_DIAG_NAME = "kron_indep_both_diag" APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" APPROX_KRONECKER_SUA_NAME = "kron_sua" # Possible value for 'reuse' keyword argument. Sets 'reuse' to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" _DEFAULT_LAYER_COLLECTION = None def get_default_layer_collection(): """Get default LayerCollection.""" if _DEFAULT_LAYER_COLLECTION is None: raise ValueError( "Attempted to retrieve default LayerCollection when none is set. Use " "LayerCollection.as_default().") return _DEFAULT_LAYER_COLLECTION def set_default_layer_collection(layer_collection): global _DEFAULT_LAYER_COLLECTION if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None: raise ValueError("Default LayerCollection is already set.") _DEFAULT_LAYER_COLLECTION = layer_collection class LayerParametersDict(OrderedDict): """An OrderedDict where keys are Tensors or tuples of Tensors. Ensures that no Tensor is associated with two different keys. 
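  For illustration (with hypothetical variables `w` and `b` and hypothetical
  FisherBlocks `block_one` and `block_two`), registering a tensor under two
  different keys fails:

    d = LayerParametersDict()
    d[(w, b)] = block_one
    d[w] = block_two  # raises ValueError: Key(s) already present: {w}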
""" def __init__(self, *args, **kwargs): self._tensors = set() super(LayerParametersDict, self).__init__(*args, **kwargs) def __setitem__(self, key, value): key = self._canonicalize_key(key) tensors = key if isinstance(key, (tuple, list)) else (key,) key_collisions = self._tensors.intersection(tensors) if key_collisions: raise ValueError("Key(s) already present: {}".format(key_collisions)) self._tensors.update(tensors) super(LayerParametersDict, self).__setitem__(key, value) def __delitem__(self, key): key = self._canonicalize_key(key) self._tensors.remove(key) super(LayerParametersDict, self).__delitem__(key) def __getitem__(self, key): key = self._canonicalize_key(key) return super(LayerParametersDict, self).__getitem__(key) def __contains__(self, key): key = self._canonicalize_key(key) return super(LayerParametersDict, self).__contains__(key) def _canonicalize_key(self, key): if isinstance(key, (list, tuple)): return tuple(key) return key # TODO(b/68034464): add capability for LayerCollection to be "finalized" # and do this when it gets used by FisherEstimator / KfacOptimizer. class LayerCollection(object): """Registry of information about layers and losses. Note that you need to create a new one of these for each FisherEstimator or KfacOptimizer, as they can't be used more than once. The methods that you should interact with directly are: - register_XXX() - auto_register_layers() Additional control over the automatic registration process can be exerted by using the methods/properties: - set_default_XXX() and default_XXX - define_linked_parameters() and linked_parameters Attributes: fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer parameters (Tensors or tuples of Tensors) to FisherBlock instances. fisher_factors: an OrderedDict mapping tuples to FisherFactor instances. losses: a list of LossFunction objects. The loss to be optimized is their sum. loss_colocation_ops: ops to colocate loss function evaluations with. These will typically be the inputs to the losses. """ def __init__(self, graph=None, name="LayerCollection"): self.fisher_blocks = LayerParametersDict() self.fisher_factors = OrderedDict() self._linked_parameters = dict( ) # dict mapping sets of variables to optionally specified approximations. 
self._graph = graph or tf.get_default_graph() self._loss_dict = OrderedDict() # {str: LossFunction} self._subgraph = None self._default_generic_approximation = APPROX_DIAGONAL_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME self._default_conv2d_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_multi_approximation = ( APPROX_KRONECKER_INDEP_NAME) self._default_conv2d_multi_approximation = ( APPROX_KRONECKER_INDEP_NAME) self._default_scale_and_shift_approximation = APPROX_FULL_NAME self.loss_colocation_ops = {} self.loss_coeffs = {} self._vars_to_uses = defaultdict(lambda: 0) self._finalized = False with tf.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name self._generic_approx_to_block_types = { APPROX_FULL_NAME: fb.NaiveFullFB, APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, } self._fully_connected_approx_to_block_types = { APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, APPROX_KRONECKER_IN_DIAG_NAME: partial(fb.FullyConnectedKFACBasicFB, diagonal_approx_for_input=True), APPROX_KRONECKER_OUT_DIAG_NAME: partial(fb.FullyConnectedKFACBasicFB, diagonal_approx_for_output=True), APPROX_KRONECKER_BOTH_DIAG_NAME: partial(fb.FullyConnectedKFACBasicFB, diagonal_approx_for_input=True, diagonal_approx_for_output=True), APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, } self._conv2d_approx_to_block_types = { APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, APPROX_KRONECKER_SUA_NAME: fb.ConvKFCBasicFB, } self._fully_connected_multi_approx_to_block_types = { APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB, APPROX_KRONECKER_INDEP_IN_DIAG_NAME: partial(fb.FullyConnectedMultiIndepFB, diagonal_approx_for_input=True), APPROX_KRONECKER_INDEP_OUT_DIAG_NAME: partial(fb.FullyConnectedMultiIndepFB, diagonal_approx_for_output=True), APPROX_KRONECKER_INDEP_BOTH_DIAG_NAME: partial(fb.FullyConnectedMultiIndepFB, diagonal_approx_for_input=True, diagonal_approx_for_output=True), APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB, option=1), APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB, option=2) } self._conv2d_multi_approx_to_block_types = { APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB } self._scale_and_shift_approx_to_block_types = { APPROX_FULL_NAME: fb.ScaleAndShiftFullFB, APPROX_DIAGONAL_NAME: fb.ScaleAndShiftDiagonalFB } @property def losses(self): """Tuple of LossFunction objects registered with this LayerCollection.""" return nest.flatten(self.towers_by_loss) @property def towers_by_loss(self): """Tuple across losses of LossFunction objects registered to each tower.""" return tuple(tuple(lst) for lst in self._loss_dict.values()) @property def registered_variables(self): """A tuple of all of the variables currently registered.""" tuple_of_tuples = (utils.ensure_sequence(key) for key, block in six.iteritems(self.fisher_blocks)) flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_) return flat_tuple @property def linked_parameters(self): """Groups of parameters with an optionally specified approximation. Linked parameters can be added using `define_linked_parameters`. If an approximation is specified, then this approximation will be used when registering a layer with exactly these parameters, unless an approximation is specified when calling the registration function. Returns: A `dict` mapping tuples of parameters to an optional string. 
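
    For example (a minimal sketch, not from the original code, assuming `w`
    and `b` are the tf.Variable weight and bias of some layer that
    auto_register_layers() will later scan):

      lc = LayerCollection()
      lc.define_linked_parameters((w, b), approximation=APPROX_DIAGONAL_NAME)
      # ... register the model's losses, then:
      lc.auto_register_layers()
      # (w, b) will only be matched as a single group during scanning, and
      # will be registered with the "diagonal" approximation.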
""" return self._linked_parameters @property def default_generic_approximation(self): return self._default_generic_approximation def set_default_generic_approximation(self, value): if value not in self._generic_approx_to_block_types: raise ValueError( "{} is not a valid approximation for generic variables.".format( value)) self._default_generic_approximation = value @property def default_fully_connected_approximation(self): return self._default_fully_connected_approximation def set_default_fully_connected_approximation(self, value): if value not in self._fully_connected_approx_to_block_types: raise ValueError( "{} is not a valid approximation for fully connected layers.".format( value)) self._default_fully_connected_approximation = value @property def default_conv2d_approximation(self): return self._default_conv2d_approximation def set_default_conv2d_approximation(self, value): if value not in self._conv2d_approx_to_block_types: raise ValueError( "{} is not a valid approximation for 2d convolutional layers.".format( value)) self._default_conv2d_approximation = value @property def default_fully_connected_multi_approximation(self): return self._default_fully_connected_multi_approximation def set_default_fully_connected_multi_approximation(self, value): if value not in self._fully_connected_multi_approx_to_block_types: raise ValueError("{} is not a valid approximation for a fully-connected " "multi layer.".format(value)) self._default_fully_connected_multi_approximation = value @property def default_conv2d_multi_approximation(self): return self._default_conv2d_multi_approximation def set_default_conv2d_multi_approximation(self, value): if value not in self._conv2d_multi_approx_to_block_types: raise ValueError("{} is not a valid approximation for a conv2d " "multi layer.".format(value)) self._default_conv2d_multi_approximation = value @property def default_scale_and_shift_approximation(self): return self._default_scale_and_shift_approximation def set_default_scale_and_shift_approximation(self, value): if value not in self._scale_and_shift_approx_to_block_types: raise ValueError("{} is not a valid approximation for a scale & shift " "layer.".format(value)) self._default_scale_and_shift_approximation = value def auto_register_layers(self, var_list=None, batch_size=None): """Registers remaining unregistered layers automatically using a scanner. Requires all function / distribution registrations to be performed (manually) first. Registrations will be performed using the default approximation mode for each type, as if the scanner were calling the user-level registration functions in this LayerCollection object (which it will be). These defaults can be overridden using the set_default_XXX_approximation methods for types of layers, or using the define_linked_parameters method for specific parameters. This function should only be called after any desired manual registrations are performed. For example, if you have a layer which isn't recognized properly by the scanner, or a layer which you want to register differently. Note that this function is an experimental convenience feature which won't work for every possible model architecture. Any layers/parameters that whose structure is not recognized will be registered as "generic", which is the worst curvature matrix approximation available in the system, and should be avoided if possible. See the docstring for register_layers in graph_search.py for more details. Args: var_list: A list of variables that the automatic registration should consider. 
If you have some trainable variables (i.e. those included in tf.trainable_variables()) that you don't want included you need to pass in this list. (Default: tf.trainable_variables()). batch_size: A `int` representing the batch size. Needs to specified if registering generic variables that don't match any layer patterns or if time/uses is folded. If the time/uses dimension is merged with batch then this is used to infer number of uses/time-steps. NOTE: In the replicated context this must be the per-replica batch size, and not the total batch size. """ if var_list is None: var_list = tf.trainable_variables() graph_search.register_layers(self, var_list, batch_size=batch_size) def finalize(self): if not self._finalized: self._create_subgraph() self._finalized = True else: raise ValueError("LayerCollection was finalized a second time, which " "indicates an error. Perhaps you used the same " "LayerCollection object in multiple " "optimizers/estimators, which is not allowed.") def _register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. Args: layer_key: A variable or tuple of variables. The key to check for in existing registrations and to register if valid. fisher_block: The associated `FisherBlock`. reuse: Method to use for inserting new `FisherBlock`s. One of True, False, or 'VARIABLE_SCOPE'. Raises: ValueError: If `layer_key` was already registered and reuse is `False`, if `layer_key` was registered with a different block type, or if `layer_key` shares any variables with but is not equal to a previously registered key. KeyError: If `reuse` is `True` but `layer_key` was not previously registered. Returns: The `FisherBlock` registered under `layer_key`. If `layer_key` was already registered, this will be the previously registered `FisherBlock`. """ if self._finalized: raise ValueError("You cannot register additional losses or layers after " "LayerCollection is finalized. Finalization happens " "after the estimator or optimizer object first uses " "the data in the LayerCollection. For example, when " "the minimize() method is called in " "PeriodicInvCovUpdateKfacOpt.") if reuse is VARIABLE_SCOPE: reuse = tf.get_variable_scope().reuse if reuse is True or (reuse is tf.AUTO_REUSE and layer_key in self.fisher_blocks): if layer_key not in self.fisher_blocks: raise ValueError( "reuse was True for attempted registration involving variables {}, " "but no previously registered layer was found for these. Perhaps " "reuse was set to True by mistake. One way this can happen is if " "reuse is set to True in the surrounding variable scope." "".format(layer_key)) result = self.fisher_blocks[layer_key] if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck raise ValueError( "Attempted to register FisherBlock of type %s when existing " "FisherBlock has type %s." % (type(fisher_block), type(result))) return result if reuse is False and layer_key in self.fisher_blocks: raise ValueError("FisherBlock for %s is already in LayerCollection." % (layer_key,)) # Insert fisher_block into self.fisher_blocks. if layer_key in self.fisher_blocks: raise ValueError("Duplicate registration: {}".format(layer_key)) # Raise an error if any variable in layer_key has been registered in any # other blocks. 
variable_to_block = { var: (params, block) for (params, block) in self.fisher_blocks.items() for var in utils.ensure_sequence(params) } for variable in utils.ensure_sequence(layer_key): if variable in variable_to_block: prev_key, prev_block = variable_to_block[variable] raise ValueError( "Attempted to register layer_key {} with block {}, but variable {}" " was already registered in key {} with block {}.".format( layer_key, fisher_block, variable, prev_key, prev_block)) self.fisher_blocks[layer_key] = fisher_block return fisher_block def _register_loss_function(self, loss, colocation_op, base_name, name=None, coeff=1.0, reuse=VARIABLE_SCOPE): """Registers a LossFunction object. Args: loss: The LossFunction object. colocation_op: The op to colocate the loss function's computations with. base_name: The name to derive a new unique name from is the name argument is None. name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) coeff: (OPTIONAL) a scalar. coefficient on the loss function (Default: 1.0) reuse: (OPTIONAL) bool or str. If True, adds 'loss' as an additional tower for the existing loss function. Raises: ValueError: If reuse == True and name == None. ValueError: If reuse == True and seed != None. KeyError: If reuse == True and no existing LossFunction with 'name' found. KeyError: If reuse == False and existing LossFunction with 'name' found. """ if self._finalized: raise ValueError("You cannot register additional losses or layers after " "LayerCollection is finalized. Finalization happens " "after the estimator or optimizer object first uses " "the data in the LayerCollection. For example, when " "the minimize() method is called in " "PeriodicInvCovUpdateKfacOpt.") name = name or self._graph.unique_name(base_name) if reuse == VARIABLE_SCOPE: reuse = tf.get_variable_scope().reuse if reuse: if name is None: raise ValueError( "If reuse is enabled, loss function's name must be set.") loss_list = self._loss_dict.get(name, None) if loss_list is None: raise KeyError( "Unable to find loss function named {}. Register a new loss " "function with reuse=False.".format(name)) if self.loss_coeffs[loss_list[0]] != coeff: raise ValueError( "Reused loss function's coeff didn't match previous supplied " "value.") else: if name in self._loss_dict: raise KeyError( "Loss function named {} already exists. Set reuse=True to append " "another tower.".format(name)) loss_list = [] self._loss_dict[name] = loss_list loss_list.append(loss) self.loss_colocation_ops[loss] = colocation_op self.loss_coeffs[loss] = coeff def _get_use_count_map(self): """Returns a dict mapping variables to their number of registrations.""" return self._vars_to_uses def _add_uses(self, params, uses): """Register additional uses by params in the graph. Args: params: Variable or tuple of Variables. Parameters for a layer. uses: int or float. Number of additional uses for these parameters. """ params = params if isinstance(params, (tuple, list)) else (params,) for var in params: self._vars_to_uses[var] += uses def check_registration(self, variables): """Checks that all variable uses have been registered properly. Args: variables: List of variables. Raises: ValueError: If any registered variables are not included in the list. ValueError: If any variable in the list is not registered. ValueError: If any variable in the list is registered with the wrong number of "uses" in the subgraph recorded (vs the number of times that variable is actually used in the subgraph). 
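
    A typical usage (an illustrative sketch, not from the original code) is to
    call this after all registrations have been performed:

      lc.auto_register_layers()
      lc.check_registration(tf.trainable_variables())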
""" # Note that overlapping parameters (i.e. those that share variables) will # be caught by layer_collection.LayerParametersDict during registration. reg_use_map = self._get_use_count_map() error_messages = [] for var in variables: total_uses = self.subgraph.variable_uses(var) reg_uses = reg_use_map[var] if reg_uses == 0: error_messages.append("Variable {} not registered.".format(var)) elif (not math.isinf(reg_uses)) and reg_uses != total_uses: error_messages.append( "Variable {} registered with wrong number of uses ({} uses " "registered vs {} uses found in sub-graph generated from " "registered losses).".format(var, reg_uses, total_uses)) num_get_vars = len(reg_use_map) if num_get_vars > len(variables): error_messages.append("{} registered variables were not included in list." .format(num_get_vars - len(variables))) if error_messages: error_string = "\n\t".join([ "Found the following errors with variable registration:" ] + error_messages) raise ValueError(error_string) def get_blocks(self): return tuple(self.fisher_blocks.values()) def get_factors(self): return tuple(self.fisher_factors.values()) @property def graph(self): return self._graph @property def subgraph(self): return self._subgraph def define_linked_parameters(self, params, approximation=None): """Identify a set of parameters that should be grouped together. Also allows the approximation type string to be set for the given parameter grouping. During automatic graph scanning (as done by the auto_register_layers method) any matches containing variables that have been identified as part of a linked group will be filtered out unless the match parameters are exactly equal to the ones specified in the linked group. Args: params: A variable, or a tuple or list of variables. The variables to be linked. approximation: Optional string specifying the type of approximation to use for these variables. If unspecified, this layer collection's default approximation for the layer type will be used. Raises: ValueError: If the parameters were already registered in a layer or identified as part of an incompatible group. """ params = frozenset(utils.ensure_sequence(params)) # Check if any of the variables in 'params' is already in # 'self.fisher_blocks.keys()'. for registered_params, fisher_block in self.fisher_blocks.items(): registered_params_set = set(utils.ensure_sequence(registered_params)) for variable in params: if (variable in registered_params_set and params != registered_params_set): raise ValueError( "Can't link parameters {}, variable {} was already registered in " "group {} with layer {}".format(params, variable, registered_params, fisher_block)) # Check if any of the variables in 'params' is already in # 'self.linked_parameters'. 
for variable in params: for other_linked_params in self.linked_parameters: if variable in other_linked_params: raise ValueError("Can't link parameters {}, variable {} was already " "linked in group {}.".format(params, variable, other_linked_params)) self._linked_parameters[params] = approximation def _create_subgraph(self): if not self.losses: raise ValueError("Must have at least one registered loss.") inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses)) self._subgraph = utils.SubGraph(inputs_to_losses) def eval_losses(self, target_mode="data", coeff_mode="regular"): """Returns evaluated losses (colocated with inputs to losses).""" evals = [] for loss in self.losses: with tf.colocate_with(self.loss_colocation_ops[loss]): if target_mode == "data": loss_value = loss.evaluate() elif target_mode == "sample": loss_value = loss.evaluate_on_sample() else: raise ValueError("target_mode must be in ['data', 'sample']") if coeff_mode == "regular": multiplier = self.loss_coeffs[loss] elif coeff_mode == "sqrt": multiplier = tf.sqrt(self.loss_coeffs[loss]) elif coeff_mode == "off": multiplier = 1.0 else: raise ValueError("coeff_mode must be in ['regular', 'sqrt', 'off']") multiplier = tf.cast(multiplier, dtype=loss_value.dtype) evals.append(multiplier * loss_value) return evals def total_loss(self, coeff_mode="regular"): return tf.add_n(self.eval_losses(target_mode="data", coeff_mode=coeff_mode)) def total_sampled_loss(self, coeff_mode="regular"): return tf.add_n(self.eval_losses(target_mode="sample", coeff_mode=coeff_mode)) def _get_linked_approx(self, params): """If params were linked, return their specified approximation.""" params_set = frozenset(utils.ensure_sequence(params)) if params_set in self.linked_parameters: return self.linked_parameters[params_set] else: return None def _get_block_type(self, params, approx, default, approx_to_type): if approx is None: approx = self._get_linked_approx(params) if approx is None: approx = default if approx not in approx_to_type: raise ValueError("Bad value {} for approx.".format(approx)) return approx_to_type[approx], approx def register_fully_connected(self, params, inputs, outputs, approx=None, dense_inputs=True, reuse=VARIABLE_SCOPE): """Registers a fully connected layer. Args: params: Variable or 2-tuple of variables corresponding to weight and bias parameters of this layer. Weight matrix should have shape [input_size, output_size]. Bias should have shape [output_size]. inputs: Tensor. Two formats are accepted. In most cases the Tensor is dense inputs, with shape [batch_size, input_size]. In some cases the Tensor is sparse inputs, with shape [batch_size]. A typical example of sparse inputs is the vocab indices into an embedding matrix. Sparse inputs will be converted to the dense format within KFAC. For sparse inputs, dense_inputs should be set to False. outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. approx: str or None. If not None must be one of "kron", "kron_in_diag" (diagonal approximation for the input kronecker factor), "kron_out_diag" (diagonal approximation for the output kronecker factor), "kron_both_diag" or "diagonal". The Fisher approximation to use. If None the default value is used. (Default: None) dense_inputs: bool. True if inputs are dense inputs. (Default: True) reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an additional mini-batch/tower of data to use when estimating the Fisher block for this layer (which must have already been registered). 
If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ block_type, approx = self._get_block_type( params, approx, self.default_fully_connected_approximation, self._fully_connected_approx_to_block_types) has_bias = isinstance(params, (tuple, list)) block = self._register_block( params, block_type(self, has_bias=has_bias), reuse=reuse) if not dense_inputs: inputs.one_hot_depth = int(params.shape[0]) block.register_additional_tower(inputs, outputs) self._add_uses(params, 1) def register_conv1d(self, params, strides, padding, inputs, outputs, dilations=None, approx=None, reuse=VARIABLE_SCOPE, sub_sample_inputs=None, sub_sample_patches=None): """Registers a call to tf.nn.conv1d(). Args: params: Variablle or 2-tuple of variables corresponding to weight and bias parameters this layer. Weight matrix should have shape [kernel_width, in_channels, out_channels]. Bias should have shape [out_channels]. strides: List of 3 ints. Strides for convolution kernel. padding: string. see tf.nn.conv2d for valid values. inputs: Tensor of shape [batch_size, width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, width, out_channels]. Output produced by layer. dilations: List of 3 ints. Dilations along each dimension. approx: str or None. If not None, must be "kron". The Fisher approximation to use. If None, the default value is used. (Default: None) reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an additional mini-batch/tower of data to use when estimating the Fisher block for this layer (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") sub_sample_inputs: `bool`. If True, then subsample the inputs from which the image patches are extracted. (Default: None) sub_sample_patches: `bool`, If `True` then subsample the extracted patches. (Default: None) Raises: KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ assert approx is None or approx == APPROX_KRONECKER_NAME block = self._register_block( params, fb.ConvKFCBasicFB( layer_collection=self, params=params, padding=padding, strides=strides, data_format="NWC", dilation_rate=dilations, extract_patches_fn="extract_convolution_patches", sub_sample_inputs=sub_sample_inputs, sub_sample_patches=sub_sample_patches, use_sua_approx_for_input_factor=False), reuse=reuse) block.register_additional_tower(inputs, outputs) self._add_uses(params, 1) def register_conv2d(self, params, strides, padding, inputs, outputs, data_format=None, dilations=None, approx=None, reuse=VARIABLE_SCOPE, sub_sample_inputs=None, sub_sample_patches=None, patch_mask=None): """Registers a call to tf.nn.conv2d(). Args: params: Variable or 2-tuple of variables corresponding to weight and bias parameters of this layer. Weight matrix should have shape [kernel_height, kernel_width, in_channels, out_channels]. Bias should have shape [out_channels]. strides: List of 4 ints. Strides for convolution kernel. padding: string. see tf.nn.conv2d for valid values. inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, height, width, out_channels]. Output produced by layer. data_format: str or None. Format of data. 
        If None, this should default to 'NHWC'. (Default: None)
      dilations: List of 4 ints. Dilations along each dimension.
      approx: str or None. If not None must be one of "kron" or "diagonal".
        The Fisher approximation to use. If None the default value is used.
        (Default: None)
      reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
        additional mini-batch/tower of data to use when estimating the Fisher
        block for this layer (which must have already been registered). If
        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
        (Default: "VARIABLE_SCOPE")
      sub_sample_inputs: `bool`. If True, then subsample the inputs from which
        the image patches are extracted. (Default: None)
      sub_sample_patches: `bool`. If `True` then subsample the extracted
        patches. (Default: None)
      patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]
        or None. If not None this is multiplied against the extracted patches
        Tensor (broadcasting along the batch dimension) before statistics are
        computed. This can (and probably should) be used if the filter bank
        matrix is masked in a way that is homogeneous across the output
        channels. (Other masking patterns have no direct support.) Currently
        this only works with approx="kron" or "diagonal". (Default: None)

    Raises:
      ValueError: For improper value to 'approx'.
      KeyError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
    """
    assert data_format in [None, "NHWC"]  # We don't support NCHW right now

    block_type, approx = self._get_block_type(
        params, approx, self.default_conv2d_approximation,
        self._conv2d_approx_to_block_types)

    # It feels bad to pass in configuration that has to do with the internal
    # implementation. And then we can't use the same constructor for both
    # anymore and are thus forced to use this ugly if-statement.
    # TODO(b/74793309): Clean this up?
    if approx == APPROX_KRONECKER_NAME:
      block = self._register_block(
          params,
          block_type(
              layer_collection=self,
              params=params,
              padding=padding,
              strides=strides,
              data_format=data_format,
              dilation_rate=dilations,
              extract_patches_fn="extract_image_patches",
              sub_sample_inputs=sub_sample_inputs,
              sub_sample_patches=sub_sample_patches,
              use_sua_approx_for_input_factor=False,
              patch_mask=patch_mask),
          reuse=reuse)
    elif approx == APPROX_DIAGONAL_NAME:
      assert strides[0] == strides[-1] == 1
      block = self._register_block(
          params,
          block_type(
              layer_collection=self,
              params=params,
              padding=padding,
              strides=strides,
              dilations=dilations,
              data_format=data_format,
              patch_mask=patch_mask),
          reuse=reuse)
    elif approx == APPROX_KRONECKER_SUA_NAME:
      block = self._register_block(
          params,
          block_type(
              layer_collection=self,
              params=params,
              padding=padding,
              use_sua_approx_for_input_factor=True),
          reuse=reuse)
    else:
      raise NotImplementedError(approx)

    block.register_additional_tower(inputs, outputs)

    self._add_uses(params, 1)

  def register_convolution(self,
                           params,
                           inputs,
                           outputs,
                           padding,
                           strides=None,
                           dilation_rate=None,
                           data_format=None,
                           approx=None,
                           reuse=VARIABLE_SCOPE):
    """Register a call to tf.nn.convolution().

    Unless you know what you are doing you should be using register_conv2d
    instead.

    Args:
      params: Variable or 2-tuple of variables corresponding to weight and
        bias parameters of this layer. Weight matrix should have shape
        [..filter_spatial_size.., in_channels, out_channels]. Bias should have
        shape [out_channels].
      inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels].
        Inputs to layer.
outputs: Tensor of shape [batch_size, ..output_spatial_size.., out_channels]. Output produced by layer. padding: string. see tf.nn.conv2d for valid values. strides: List of ints of length len(..input_spatial_size..). Strides for convolution kernel in spatial dimensions. dilation_rate: List of ints of length len(..input_spatial_size..). Dilations along spatial dimension. data_format: str or None. Format of data. approx: str or None. If not None, must be "kron". The Fisher approximation to use. If None, the default value is used. (Default: None) reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an additional mini-batch/tower of data to use when estimating the Fisher block for this layer (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ # TODO(b/74793309): Have this use _get_block_type like the other # registration functions? assert approx is None or approx == APPROX_KRONECKER_NAME block = self._register_block( params, fb.ConvKFCBasicFB( layer_collection=self, params=params, padding=padding, strides=strides, dilation_rate=dilation_rate, data_format=data_format), reuse=reuse) block.register_additional_tower(inputs, outputs) self._add_uses(params, 1) def register_depthwise_conv2d(self, params, inputs, outputs, strides, padding, rate=None, data_format=None, approx=None, reuse=VARIABLE_SCOPE): """Register a call to tf.nn.depthwise_conv2d(). Note that this is an experimental feature that hasn't been experimentally validated or published on. Args: params: 4-D variable of shape [filter_height, filter_width, in_channels, channel_multiplier]. Convolutional filter. inputs: Tensor of shape [batch_size, input_height, input_width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, output_height, output_width, in_channels * channel_multiplier]. Output produced by depthwise conv2d. strides: List of ints of length 4. Strides along all dimensions. padding: string. see tf.nn.conv2d for valid values. rate: None or List of ints of length 2. Dilation rates in spatial dimensions. data_format: str or None. Format of data. approx: str or None. If not None must "diagonal". The Fisher approximation to use. If None the default value is used. (Default: None) reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an additional mini-batch/tower of data to use when estimating the Fisher block for this layer (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ # TODO(b/74793309): Have this use _get_block_type like the other # registration functions? 
assert approx is None or approx == APPROX_DIAGONAL_NAME assert data_format in [None, "NHWC"] block = self._register_block( params, fb.DepthwiseConvDiagonalFB( layer_collection=self, params=params, strides=strides, padding=padding, rate=rate, data_format=data_format), reuse=reuse) block.register_additional_tower(inputs, outputs) self._add_uses(params, 1) def register_separable_conv2d(self, depthwise_params, pointwise_params, inputs, depthwise_outputs, pointwise_outputs, strides, padding, rate=None, data_format=None, approx=None, reuse=VARIABLE_SCOPE): """Register a call to tf.nn.separable_conv2d(). Note: This requires access to intermediate outputs between depthwise and pointwise convolutions. Note that this is an experimental feature that hasn't been experimentally validated or published on. Args: depthwise_params: 4-D variable of shape [filter_height, filter_width, in_channels, channel_multiplier]. Filter for depthwise conv2d. pointwise_params: 4-D variable of shape [1, 1, in_channels * channel_multiplier, out_channels]. Filter for pointwise conv2d. inputs: Tensor of shape [batch_size, input_height, input_width, in_channels]. Inputs to layer. depthwise_outputs: Tensor of shape [batch_size, output_height, output_width, in_channels * channel_multiplier]. Output produced by depthwise conv2d. pointwise_outputs: Tensor of shape [batch_size, output_height, output_width, out_channels]. Output produced by pointwise conv2d. strides: List of ints of length 4. Strides for depthwise conv2d kernel in all dimensions. padding: string. see tf.nn.conv2d for valid values. rate: None or List of ints of length 2. Dilation rate of depthwise conv2d kernel in spatial dimensions. data_format: str or None. Format of data. approx: str or None. If not None must be one of "kron" or "diagonal". The Fisher approximation to use. If None the default value is used. (Default: None) reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an additional mini-batch/tower of data to use when estimating the Fisher block for this layer (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ self.register_depthwise_conv2d( params=depthwise_params, inputs=inputs, outputs=depthwise_outputs, strides=strides, padding=padding, rate=rate, data_format=data_format, approx=APPROX_DIAGONAL_NAME, reuse=reuse) self.register_conv2d( params=pointwise_params, inputs=depthwise_outputs, outputs=pointwise_outputs, strides=[1, 1, 1, 1], padding="VALID", data_format=data_format, approx=approx, reuse=reuse) def register_generic(self, params, batch_size, approx=None, reuse=VARIABLE_SCOPE): """Registers parameters without assuming any structure. Note that this is an approximation of last resort and should be avoided if anything else will work. Args: params: Variable or tuple of variables corresponding to the parameters. If using "diagonal" approximation this must be a single variable. batch_size: 0-D Tensor. Size of the minibatch (for this tower). approx: str or None. It not None, must be one of "full" or "diagonal". The Fisher approximation to use. If None the default value is used. (Default: None) reuse: bool or str. If True, this adds 'batch_size' to the total mini-batch size use when estimating the Fisher block for this layer (which must have already been registered). 
If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. ValueError: If approx == "diagonal" and params is a tuple. """ block_type, approx = self._get_block_type( params, approx, self.default_generic_approximation, self._generic_approx_to_block_types) if approx == APPROX_DIAGONAL_NAME and isinstance(params, (tuple, list)): raise ValueError("Params must be a Variable if using the diagonal " "approximation.") block = self._register_block(params, block_type(self, params), reuse=reuse) block.register_additional_tower(batch_size) self._add_uses(params, float("inf")) def register_fully_connected_multi(self, params, inputs, outputs, num_uses=None, approx=None, dense_inputs=True, reuse=VARIABLE_SCOPE): """Register fully connected layers with shared parameters. This can handle general fully-connected layers with shared parameters, but has specialized approximations to deal with the case where there is a meaningful linear order to the share instances (such as in an RNN). Note that padding is *not* supported. The arguments to this method cannot be zero-padded or anything of that sort. Args: params: Variable or 2-tuple of variables corresponding to weight and bias of this layer. Weight matrix should have shape [input_size, output_size]. Bias should have shape [output_size]. inputs: A list of Tensors or a single Tensor. Inputs to this layer. If a list of Tensors, the list indexes each use in the model (which might correspond to a "time-step" in an RNN). Each Tensor in the list has leading dimension batch_size. If a single Tensor, should have shape [num_uses, batch_size, input_size] or be a reshape of such a tensor to shape [num_uses, batch_size, input_size]. Similar to register_fully_connected(), two formats of tensors are accepted: dense inputs and sparse inputs. In most cases the Tensors are dense inputs, with shape [batch_size, input_size] (if a list) or [num_uses, batch_size, input_size] (if a single Tensor) or [num_uses * batch_size, input_size] (if a single Tensor). In some cases the Tensors are sparse inputs, with shape [batch_size] (if a list) or or [num_uses, batch_size] (if a single Tensor) or [num_uses * batch_size] (if a single Tensor). A typical example of sparse inputs is the vocab indices into an embedding matrix. For sparse inputs, the argument 'dense_inputs' should be set to False. outputs: A list of Tensors, the same length as 'inputs', each of shape [batch_size, output_size]. Outputs produced by layer. The list indexes each use in the model (which might correspond to a "time-step" in an RNN). Needs to correspond with the order used in 'inputs'. OR, can be a single Tensor of shape [num_uses * batch_size, output_size], which is a reshaped version of a Tensor of shape [num_uses, batch_size, output_size]. num_uses: int or None. The number uses/time-steps in the model where the layer appears. Only needed if both inputs and outputs are given in the single Tensor format. (Default: None) approx: str or None. If not None, must be one of "kron_indep", "kron_indep_in_diag" (diagonal approximation for the input kronecker factor), "kron_indep_out_diag" (diagonal approximation for the output kronecker factor), "kron_indep_both_diag", "kron_series_1" or "kron_series_2". The Fisher approximation to use. If None the default value is used (which starts out as "kron_indep"). 
(Default: None) dense_inputs: bool. True if inputs are dense inputs. (Default: True) reuse: bool or str. If True, this adds inputs and outputs as an additional mini-batch/tower of data to use when estimating the Fisher block for this layer (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the word 'use' here has a completely different meaning to "use in the model" as it pertains to the 'inputs', 'outputs', and 'num_uses' arguments.) (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. """ block_type, approx = self._get_block_type( params, approx, self.default_fully_connected_multi_approximation, self._fully_connected_multi_approx_to_block_types) # TODO(b/70283649): something along the lines of find_canonical_output # should be added back in here (and for the other block types, arguably). has_bias = isinstance(params, (tuple, list)) block = self._register_block( params, block_type(self, has_bias=has_bias, num_uses=num_uses), reuse=reuse) if isinstance(inputs, (tuple, list)): inputs = tuple(inputs) if isinstance(outputs, (tuple, list)): outputs = tuple(outputs) if not dense_inputs: if isinstance(inputs, (tuple, list)): for input in inputs: input.one_hot_depth = int(params.shape[0]) else: inputs.one_hot_depth = int(params.shape[0]) block.register_additional_tower(inputs, outputs) if isinstance(inputs, (tuple, list)): assert len(inputs) == len(outputs) self._add_uses(params, len(inputs)) else: self._add_uses(params, 1) def register_conv2d_multi(self, params, strides, padding, inputs, outputs, num_uses=None, data_format=None, dilations=None, approx=None, reuse=VARIABLE_SCOPE): """Registers convolutional layers with shared parameters. Note that padding is *not* supported. The arguments to this method cannot be zero-padded or anything of that sort. Args: params: Variable or 2-tuple of variables corresponding to weight and bias of this layer. Weight matrix should have shape [kernel_height, kernel_width, in_channels, out_channels]. Bias should have shape [out_channels]. strides: 1-D Tensor of length 4. Strides for convolution kernel. padding: string. see tf.nn.conv2d for valid values. inputs: A list of Tensors, each of shape [batch_size, height, width, in_channels]. Inputs to layer. The list indexes each use in the model (which might correspond to a "time-step" in an RNN). OR, can be single Tensor, of shape [num_uses * batch_size, height, width, in_channels], which is a reshaped version of a Tensor of shape [num_uses, batch_size, height, width, in_channels]. outputs: A list of Tensors, each of shape [batch_size, height, width, out_channels]. Output produced by layer. The list indexes each use in the model (which might correspond to a "time-step" in an RNN). Needs to correspond with the order used in 'inputs'. OR, can be a single Tensor, of shape [num_uses * batch_size, height, width, out_channels], which is a reshaped version of a Tensor of shape [num_uses, batch_size, height, width, out_channels]. num_uses: int or None. The number uses/time-steps in the model where the layer appears. Only needed if both inputs and outputs are given in the single Tensor format. (Default: None) data_format: str or None. Format of data. dilations: List of 4 ints. Dilations along each dimension. approx: str or None. If not None must be "kron_indep". The Fisher approximation to use. If None the default value is used (which starts out as "kron_indep"). (Default: None) reuse: bool or str. 
If True, this adds inputs and outputs as an additional mini-batch/tower of data to use when estimating the Fisher block for this layer (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the word 'use' here has a completely different meaning to "use in the model" as it pertains to the 'inputs', 'outputs', and 'num_uses' arguments.) (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ assert data_format in [None, "NHWC"] # We don't support NCHW right now block_type, approx = self._get_block_type( params, approx, self.default_conv2d_multi_approximation, self._conv2d_multi_approx_to_block_types) block = self._register_block( params, block_type( layer_collection=self, params=params, padding=padding, strides=strides, data_format=data_format, dilation_rate=dilations, extract_patches_fn="extract_image_patches", num_uses=num_uses), reuse=reuse) if isinstance(inputs, (tuple, list)): inputs = tuple(inputs) if isinstance(outputs, (tuple, list)): outputs = tuple(outputs) block.register_additional_tower(inputs, outputs) if isinstance(inputs, (tuple, list)): assert len(inputs) == len(outputs) self._add_uses(params, len(inputs)) else: self._add_uses(params, 1) def register_scale_and_shift(self, params, inputs, outputs, approx=None, reuse=VARIABLE_SCOPE): """Registers a scale and shift operation. A scale and shift operation is a parameterized operation of the form outputs = scale * inputs + shift , where scale and shift are variables that broadcast to the shape of inputs. outputs and inputs must have batch dimension. scale and shift can have a corresponding dimension (although they don't need to), but it must be 1. These kinds of operations appear frequently in various "normalization" layers like Layer Normalization. Batch Normalization layers should still be registered as "generic". Note that this is an experimental feature that hasn't been experimentally validated or published on. Args: params: Variable or 2-tuple of Variables corresponding to the scale and possibly shift parameters (scale must be first). Note that if these have a dimension corresponding to the batch dimension of 'inputs' and 'outputs', that dimension must be 1. inputs: Tensor of shape [batch_size, ...]. Input tensor that is multiplied by the scale the scale tensor. outputs: Tensor of shape [batch_size, ...]. Final output produced by the scale and shift. Must have the same shape as 'inputs'. approx: str or None. If not None must be one of "full" or "diagonal". The Fisher approximation to use. If None the default value is used. (Default: None) reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an additional mini-batch/tower of data to use when estimating the Fisher block for this layer (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ # TODO(jamesmartens): Consider replacing some of the logic below with calls # to tf.broadcast_static_shape. 
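    # The logic below works out, for the shift and scale parameters
    # respectively, which dimensions of 'outputs'/'inputs' the parameter is
    # broadcast across: a dimension counts as a broadcast dimension if the
    # parameter has no corresponding dimension there, or has size 1 where the
    # data has a larger size. The batch dimension (index 0) is never treated
    # as a broadcast dimension. The resulting tuples are passed to the
    # FisherBlock constructor further below.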
if isinstance(params, (tuple, list)): scale = params[0] shift = params[1] has_shift = True start_dim = len(outputs.shape) - len(shift.shape) if start_dim < 0: raise ValueError("Rank of shift cannot exceed that of outputs.") if start_dim == 0 and shift.shape[0] != 1: raise ValueError("If shift has a batch dimension its value must be 1.") broadcast_dims_shift = list(range(1, start_dim)) for i in range(max(start_dim, 1), len(outputs.shape)): if shift.shape[i - start_dim] < outputs.shape[i]: if shift.shape[i - start_dim] == 1: broadcast_dims_shift.append(i) else: raise ValueError("It appears that shift param and output have " "incompatible shapes. This is probably due to " "misspecified arguments.") elif shift.shape[i - start_dim] > outputs.shape[i]: raise ValueError("It appears that shift param and output have " "incompatible shapes. This is probably due to " "misspecified arguments.") broadcast_dims_shift = tuple(broadcast_dims_shift) else: has_shift = False scale = params broadcast_dims_shift = None start_dim = len(inputs.shape) - len(scale.shape) if start_dim < 0: raise ValueError("Rank of scale cannot exceed that of inputs.") if start_dim == 0 and scale.shape[0] != 1: raise ValueError("If scale has a batch dimension its value must be 1.") broadcast_dims_scale = list(range(1, start_dim)) for i in range(max(start_dim, 1), len(inputs.shape)): if scale.shape[i - start_dim] < inputs.shape[i]: if scale.shape[i - start_dim] == 1: broadcast_dims_scale.append(i) else: raise ValueError("It appears that scale param and input have " "incompatible shapes. This is probably due to " "misspecified arguments.") broadcast_dims_scale = tuple(broadcast_dims_scale) block_type, approx = self._get_block_type( params, approx, self.default_scale_and_shift_approximation, self._scale_and_shift_approx_to_block_types) block = self._register_block(params, block_type( self, broadcast_dims_scale, broadcast_dims_shift=broadcast_dims_shift, has_shift=has_shift), reuse=reuse) block.register_additional_tower(inputs, outputs) self._add_uses(params, 1) def register_categorical_predictive_distribution(self, logits, seed=None, targets=None, name=None, coeff=1.0, reuse=VARIABLE_SCOPE): """Registers a categorical predictive distribution. Corresponds to losses computed using tf.nn.sparse_softmax_cross_entropy_with_logits. Note that this is distinct from register_multi_bernoulli_predictive_distribution and should not be confused with it. Args: logits: The logits of the distribution (i.e. its parameters). The first dimension must be the batch size. seed: The seed for the RNG (for debugging) (Default: None) targets: (OPTIONAL) The targets for the loss function. Only required if one wants to use the "empirical Fisher" instead of the true Fisher (which is controlled by the 'estimation_mode' to the optimizer). (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the log prob loss associated with this distribution. The Fisher will be multiplied by the corresponding factor. This is NOT equivalent to changing the temperature of the distribution since we don't renormalize the log prob in the objective function. (Default: 1.0) reuse: (OPTIONAL) bool or str. If True, this adds 'logits' as an additional mini-batch/tower of inputs to the loss-function/predictive distribution (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. 
(Default: "VARIABLE_SCOPE") """ loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets, seed=seed) self._register_loss_function(loss, logits, "categorical_predictive_distribution", name=name, coeff=coeff, reuse=reuse) def register_softmax_cross_entropy_loss(self, logits, seed=None, targets=None, name=None, coeff=1.0, reuse=VARIABLE_SCOPE): """Registers a softmax cross-entropy loss function. Corresponds to losses computed using tf.nn.sparse_softmax_cross_entropy_with_logits. Note that this is distinct from register_sigmoid_cross_entropy_loss and should not be confused with it. It is similar to register_categorical_predictive_distribution but without the explicit probabilistic interpretation. It behaves identically for now. Args: logits: The logits of the distribution (i.e. its parameters). The first dimension must be the batch size. seed: The seed for the RNG (for debugging) (Default: None) targets: (OPTIONAL) The targets for the loss function. Only required if one wants to use the "empirical Fisher" instead of the true Fisher (which is controlled by the 'estimation_mode' to the optimizer). (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the loss function by. (Default: 1.0) reuse: (OPTIONAL) bool or str. If True, this adds 'logits' as an additional mini-batch/tower of inputs to the loss-function/predictive distribution (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets, seed=seed) self._register_loss_function(loss, logits, "sparse_softmax_cross_entropy_loss", name=name, coeff=coeff, reuse=reuse) def register_normal_predictive_distribution(self, mean, var=0.5, seed=None, targets=None, name=None, coeff=1.0, reuse=VARIABLE_SCOPE): """Registers a normal predictive distribution. This corresponds to a squared error loss of the form coeff/(2*var) * ||target - mean||^2 Args: mean: A tensor defining the mean vector of the distribution. The first dimension must be the batch size. var: float. The variance of the distribution. Note that the default value of 0.5 corresponds to a standard squared error loss coeff*||target - prediction||^2. If you want your squared error loss to be of the form 0.5*coeff*||target - prediction||^2 you should use var=1.0. (Default: 0.5) seed: The seed for the RNG (for debugging) (Default: None) targets: (OPTIONAL) The targets for the loss function. Only required if one wants to use the "empirical Fisher" instead of the true Fisher (which is controlled by the 'estimation_mode' to the optimizer). (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the log prob loss associated with this distribution. The Fisher will be multiplied by the corresponding factor. In general this is NOT equivalent to changing the temperature of the distribution, but in the case of normal distributions it may be. (Default: 1.0) reuse: (OPTIONAL) bool or str. If True, this adds 'mean' and 'var' as an additional mini-batch/tower of inputs to the loss-function/predictive distribution (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. 
(Default: "VARIABLE_SCOPE") """ loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets, seed=seed) self._register_loss_function(loss, mean, "normal_predictive_distribution", name=name, coeff=coeff, reuse=reuse) def register_squared_error_loss(self, prediction, seed=None, targets=None, name=None, coeff=1.0, reuse=VARIABLE_SCOPE): """Registers a squared error loss function. This assumes the squared error loss of the form ||target - prediction||^2, averaged across the mini-batch. If your loss uses a coefficient of 0.5 (tf.nn.l2_loss does this, for example) you need to set the "coeff" argument to reflect this. Args: prediction: The prediction made by the network (i.e. its output). The first dimension must be the batch size. seed: The seed for the RNG (for debugging) (Default: None) targets: (OPTIONAL) The targets for the loss function. Only required if one wants to use the "empirical Fisher" instead of the true Fisher (which is controlled by the 'estimation_mode' to the optimizer). (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the loss function by. (Default: 1.0) reuse: (OPTIONAL) bool or str. If True, this adds 'prediction' as an additional mini-batch/tower of inputs to the loss-function/predictive distribution (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ loss = lf.NormalMeanNegativeLogProbLoss(prediction, var=0.5, targets=targets, seed=seed) self._register_loss_function(loss, prediction, "squared_error_loss", name=name, coeff=coeff, reuse=reuse) def register_multi_bernoulli_predictive_distribution(self, logits, seed=None, targets=None, name=None, coeff=1.0, reuse=VARIABLE_SCOPE): """Registers a multi-Bernoulli predictive distribution. Corresponds to losses computed using tf.nn.sigmoid_cross_entropy_with_logits. Note that this is distinct from register_categorical_predictive_distribution and should not be confused with it. Args: logits: The logits of the distribution (i.e. its parameters). The first dimension must be the batch size. seed: The seed for the RNG (for debugging) (Default: None) targets: (OPTIONAL) The targets for the loss function. Only required if one wants to use the "empirical Fisher" instead of the true Fisher (which is controlled by the 'estimation_mode' to the optimizer). (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the log prob loss associated with this distribution. The Fisher will be multiplied by the corresponding factor. This is NOT equivalent to changing the temperature of the distribution since we don't renormalize the log prob in the objective function. (Default: 1.0) reuse: (OPTIONAL) bool or str. If True, this adds 'logits' as an additional mini-batch/tower of inputs to the loss-function/predictive distribution (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. 
(Default: "VARIABLE_SCOPE") """ loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets, seed=seed) self._register_loss_function(loss, logits, "multi_bernoulli_predictive_distribution", name=name, coeff=coeff, reuse=reuse) def register_sigmoid_cross_entropy_loss(self, logits, seed=None, targets=None, name=None, coeff=1.0, reuse=VARIABLE_SCOPE): """Registers a sigmoid cross-entropy loss function. Corresponds to losses computed using tf.nn.sigmoid_cross_entropy_with_logits. Note that this is distinct from register_softmax_cross_entropy_loss and should not be confused with it. It is similar to register_multi_bernoulli_predictive_distribution but without the explicit probabilistic interpretation. It behaves identically for now. Args: logits: The logits tensor. The first dimension must be the batch size. seed: The seed for the RNG (for debugging) (Default: None) targets: (OPTIONAL) The targets for the loss function. Only required if one wants to use the "empirical Fisher" instead of the true Fisher (which is controlled by the 'estimation_mode' to the optimizer). (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) coeff: (OPTIONAL) a float or TF scalar. A coefficient to multiply the loss function by. (Default: 1.0) reuse: (OPTIONAL) bool or str. If True, this adds 'logits' as an additional mini-batch/tower of inputs to the loss-function/predictive distribution (which must have already been registered). If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets, seed=seed) self._register_loss_function(loss, logits, "sigmoid_cross_entropy_loss", name=name, coeff=coeff, reuse=reuse) def make_or_get_factor(self, cls, args): """Insert 'cls(args)' into 'self.fisher_factors' if not already present. Wraps constructor in 'tf.variable_scope()' to ensure variables constructed in 'cls.__init__' are placed under this LayerCollection's scope. Args: cls: Class that implements FisherFactor. args: Tuple of arguments to pass into 'cls's constructor. Must be hashable. Returns: Instance of 'cls' found in self.fisher_factors. """ # TODO(b/123190346): Should probably change the args list to be keyworded # instead of positional. Note that this would require making changes in # each FisherBlock's call to make_or_get_factor. try: hash(args) except TypeError: raise TypeError( ("Unable to use (cls, args) = ({}, {}) as a key in " "LayerCollection.fisher_factors. The pair cannot be hashed.").format( cls, args)) key = cls, args if key not in self.fisher_factors: with tf.variable_scope(self._var_scope): self.fisher_factors[key] = cls(*args) return self.fisher_factors[key] @contextmanager def as_default(self): """Sets this LayerCollection as the default.""" set_default_layer_collection(self) yield set_default_layer_collection(None) ================================================ FILE: kfac/python/ops/linear_operator.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Extra functionality we need for LinearOperators.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import tensorflow.compat.v1 as tf from kfac.python.ops import utils linalg = tf.linalg class LinearOperatorExtras(object): # pylint: disable=missing-docstring def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): # pylint: disable=missing-docstring with self._name_scope(name): if isinstance(x, tf.IndexedSlices): return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg) x = tf.convert_to_tensor(x, name="x") self._check_input_dtype(x) self_dim = -2 if adjoint else -1 arg_dim = -1 if adjoint_arg else -2 tf.TensorShape(self.shape[self_dim]).assert_is_compatible_with( x.get_shape()[arg_dim]) return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"): # pylint: disable=missing-docstring with self._name_scope(name): if isinstance(x, tf.IndexedSlices): return self._matmul_right_sparse( x, adjoint=adjoint, adjoint_arg=adjoint_arg) x = tf.convert_to_tensor(x, name="x") self._check_input_dtype(x) self_dim = -1 if adjoint else -2 arg_dim = -2 if adjoint_arg else -1 tf.TensorShape(self.shape[self_dim]).assert_is_compatible_with( x.get_shape()[arg_dim]) return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg) class LinearOperatorFullMatrix(LinearOperatorExtras, # pylint: disable=missing-docstring linalg.LinearOperatorFullMatrix): def _matmul_right(self, x, adjoint=False, adjoint_arg=False): return linalg.matmul( x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint) def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): raise NotImplementedError def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): assert not adjoint and not adjoint_arg return utils.matmul_sparse_dense(x, self._matrix) class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring linalg.LinearOperatorDiag): def _matmul_right(self, x, adjoint=False, adjoint_arg=False): diag_mat = tf.conj(self._diag) if adjoint else self._diag x = linalg.adjoint(x) if adjoint_arg else x return diag_mat * x def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): diag_mat = tf.conj(self._diag) if adjoint else self._diag assert not adjoint_arg return utils.matmul_diag_sparse(diag_mat, x) def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): raise NotImplementedError ================================================ FILE: kfac/python/ops/loss_functions.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Loss functions to be used by LayerCollection.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import abc # Dependency imports import six import tensorflow.compat.v1 as tf import tensorflow_probability as tfp @six.add_metaclass(abc.ABCMeta) class LossFunction(object): """Abstract base class for loss functions. Note that unlike typical loss functions used in neural networks these are summed and not averaged across cases in the batch, since this is what the users of this class (FisherEstimator and MatrixVectorProductComputer) will be expecting. The implication of this is that you will may want to normalize things like Fisher-vector products by the batch size when you use this class. It depends on the use case. """ @abc.abstractproperty def targets(self): """The targets being predicted by the model. Returns: None or Tensor of appropriate shape for calling self._evaluate() on. """ pass @abc.abstractproperty def inputs(self): """The inputs to the loss function (excluding the targets).""" pass def evaluate(self): """Evaluate the loss function on the targets.""" if self.targets is not None: # We treat the targets as "constant". It's only the inputs that get # "back-propped" through. return self._evaluate(tf.stop_gradient(self.targets)) else: raise Exception("Cannot evaluate losses with unspecified targets.") @abc.abstractmethod def _evaluate(self, targets): """Evaluates the negative log probability of the targets. Args: targets: Tensor that distribution can calculate log_prob() of. Returns: negative log probability of each target, summed across all targets. """ pass @abc.abstractmethod def multiply_ggn(self, vector): """Right-multiply a vector by the GGN. Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) of the loss function with respect to its inputs. Args: vector: The vector to multiply. Must be the same shape(s) as the 'inputs' property. Returns: The vector right-multiplied by the GGN. Will be of the same shape(s) as the 'inputs' property. """ pass @abc.abstractmethod def multiply_ggn_factor(self, vector): """Right-multiply a vector by a factor B of the GGN. Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) of the loss function with respect to its inputs. Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases. Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class. Args: vector: The vector to multiply. Must be of the shape given by the 'ggn_factor_inner_shape' property. Returns: The vector right-multiplied by B. Will be of the same shape(s) as the 'inputs' property. """ pass @abc.abstractmethod def multiply_ggn_factor_transpose(self, vector): """Right-multiply a vector by the transpose of a factor B of the GGN. Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) of the loss function with respect to its inputs. 
Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases. Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class. Args: vector: The vector to multiply. Must be the same shape(s) as the 'inputs' property. Returns: The vector right-multiplied by B^T. Will be of the shape given by the 'ggn_factor_inner_shape' property. """ pass @abc.abstractmethod def multiply_ggn_factor_replicated_one_hot(self, index): """Right-multiply a replicated-one-hot vector by a factor B of the GGN. Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) of the loss function with respect to its inputs. Typically this will be block-diagonal across different cases in the batch, since the loss function is typically summed across cases. A 'replicated-one-hot' vector means a tensor which, for each slice along the batch dimension (assumed to be dimension 0), is 1.0 in the entry corresponding to the given index and 0 elsewhere. Note that B can be any matrix satisfying B * B^T = G where G is the GGN, but will agree with the one used in the other methods of this class. Args: index: A tuple representing in the index of the entry in each slice that is 1.0. Note that len(index) must be equal to the number of elements of the 'ggn_factor_inner_shape' tensor minus one. Returns: The vector right-multiplied by B^T. Will be of the same shape(s) as the 'inputs' property. """ pass @abc.abstractproperty def ggn_factor_inner_shape(self): """The shape of the tensor returned by multiply_ggn_factor.""" pass @abc.abstractproperty def ggn_factor_inner_static_shape(self): """Static version of ggn_factor_inner_shape.""" pass @property def dtype(self): if isinstance(self.inputs, (list, tuple)): return self.inputs[0].dtype return self.inputs.dtype @six.add_metaclass(abc.ABCMeta) class NegativeLogProbLoss(LossFunction): """Abstract base class for loss functions that are negative log probs.""" def __init__(self, seed=None): self._default_seed = seed super(NegativeLogProbLoss, self).__init__() @property def inputs(self): return self.params @abc.abstractproperty def params(self): """Parameters to the underlying distribution.""" pass @abc.abstractmethod def multiply_fisher(self, vector): """Right-multiply a vector by the Fisher. Args: vector: The vector to multiply. Must be the same shape(s) as the 'inputs' property. Returns: The vector right-multiplied by the Fisher. Will be of the same shape(s) as the 'inputs' property. """ pass @abc.abstractmethod def multiply_fisher_factor(self, vector): """Right-multiply a vector by a factor B of the Fisher. Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases. Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, but will agree with the one used in the other methods of this class. Args: vector: The vector to multiply. Must be of the shape given by the 'fisher_factor_inner_shape' property. Returns: The vector right-multiplied by B. Will be of the same shape(s) as the 'inputs' property. 
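Note that because F = B * B^T, multiply_fisher(vector) is equivalent to
multiply_fisher_factor(multiply_fisher_factor_transpose(vector)).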
""" pass @abc.abstractmethod def multiply_fisher_factor_transpose(self, vector): """Right-multiply a vector by the transpose of a factor B of the Fisher. Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases. Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, but will agree with the one used in the other methods of this class. Args: vector: The vector to multiply. Must be the same shape(s) as the 'inputs' property. Returns: The vector right-multiplied by B^T. Will be of the shape given by the 'fisher_factor_inner_shape' property. """ pass @abc.abstractmethod def multiply_fisher_factor_replicated_one_hot(self, index): """Right-multiply a replicated-one-hot vector by a factor B of the Fisher. Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- product of gradients) with respect to the parameters of the underlying probability distribution (whose log-prob defines the loss). Typically this will be block-diagonal across different cases in the batch, since the distribution is usually (but not always) conditionally iid across different cases. A 'replicated-one-hot' vector means a tensor which, for each slice along the batch dimension (assumed to be dimension 0), is 1.0 in the entry corresponding to the given index and 0 elsewhere. Note that B can be any matrix satisfying B * B^T = H where H is the Fisher, but will agree with the one used in the other methods of this class. Args: index: A tuple representing in the index of the entry in each slice that is 1.0. Note that len(index) must be equal to the number of elements of the 'fisher_factor_inner_shape' tensor minus one. Returns: The vector right-multiplied by B. Will be of the same shape(s) as the 'inputs' property. """ pass @abc.abstractproperty def fisher_factor_inner_shape(self): """The shape of the tensor returned by multiply_fisher_factor.""" pass @abc.abstractproperty def fisher_factor_inner_static_shape(self): """Static version of fisher_factor_inner_shape.""" pass @abc.abstractmethod def sample(self, seed): """Sample 'targets' from the underlying distribution.""" pass def evaluate_on_sample(self, seed=None): """Evaluates the log probability on a random sample. Args: seed: int or None. Random seed for this draw from the distribution. Returns: Log probability of sampled targets, summed across examples. """ if seed is None: seed = self._default_seed # We treat the targets as "constant". It's only the inputs that get # "back-propped" through. return self._evaluate(tf.stop_gradient(self.sample(seed))) class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss): """Base class for neg log prob losses whose inputs are 'natural' parameters. We will take the GGN of the loss to be the Fisher associated with the distribution, which also happens to be equal to the Hessian for this class of loss functions. See here: https://arxiv.org/abs/1412.1193 'Natural parameters' are defined for exponential-family models. 
See for example: https://en.wikipedia.org/wiki/Exponential_family """ def multiply_ggn(self, vector): return self.multiply_fisher(vector) def multiply_ggn_factor(self, vector): return self.multiply_fisher_factor(vector) def multiply_ggn_factor_transpose(self, vector): return self.multiply_fisher_factor_transpose(vector) def multiply_ggn_factor_replicated_one_hot(self, index): return self.multiply_fisher_factor_replicated_one_hot(index) @property def ggn_factor_inner_shape(self): return self.fisher_factor_inner_shape @property def ggn_factor_inner_static_shape(self): return self.fisher_factor_inner_shape class DistributionNegativeLogProbLoss(NegativeLogProbLoss): """Base class for neg log prob losses that use the TF Distribution classes.""" def __init__(self, seed=None): super(DistributionNegativeLogProbLoss, self).__init__(seed=seed) @abc.abstractproperty def dist(self): """The underlying tfp.distributions.Distribution.""" pass def _evaluate(self, targets): return -tf.reduce_sum(self.dist.log_prob(targets)) def sample(self, seed): return self.dist.sample(seed=seed) class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, NaturalParamsNegativeLogProbLoss): """Neg log prob loss for a normal distribution parameterized by a mean vector. Note that the covariance is treated as a constant 'var' times the identity. Also note that the Fisher for such a normal distribution with respect the mean parameter is given by: F = (1/var) * I See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf. """ def __init__(self, mean, var=0.5, targets=None, seed=None): assert isinstance(var, float) # variance must be a constant float self._mean = mean self._var = var self._targets = targets super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed) @property def targets(self): return self._targets @property def dist(self): return tfp.distributions.Normal(loc=self._mean, scale=tf.sqrt(self._var)) @property def params(self): return self._mean def multiply_fisher(self, vector): return (1. / self._var) * vector def multiply_fisher_factor(self, vector): return self._var**-0.5 * vector def multiply_fisher_factor_transpose(self, vector): return self.multiply_fisher_factor(vector) # it's symmetric in this case def multiply_fisher_factor_replicated_one_hot(self, index): assert len(index) == 1, "Length of index was {}".format(len(index)) ones_slice = tf.expand_dims( tf.ones(tf.shape(self._mean)[:1], dtype=self._mean.dtype), axis=-1) output_slice = self._var**-0.5 * ones_slice return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]), index[0]) @property def fisher_factor_inner_shape(self): return tf.shape(self._mean) @property def fisher_factor_inner_static_shape(self): return self._mean.shape class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): """Negative log prob loss for a normal distribution with mean and variance. This class parameterizes a multivariate normal distribution with n independent dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not assume the variance is held constant. The Fisher Information for n = 1 is given by, F = [[1 / variance, 0], [ 0, 0.5 / variance^2]] where the parameters of the distribution are concatenated into a single vector as [mean, variance]. For n > 1, the mean parameter vector is concatenated with the variance parameter vector. See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation. 
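For n = 1 this Fisher factorizes as F = B * B^T with

  B = [[1 / sqrt(variance),            0           ],
       [         0,          1 / (sqrt(2) * variance)]]

which is the (diagonal) factor used by the multiply_fisher_factor* methods of
this class.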
""" def __init__(self, mean, variance, targets=None, seed=None): assert len(mean.shape) == 2, "Expect 2D mean tensor." assert len(variance.shape) == 2, "Expect 2D variance tensor." self._mean = mean self._variance = variance self._targets = targets super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed) @property def targets(self): return self._targets @property def dist(self): return tfp.distributions.Normal( loc=self._mean, scale=tf.sqrt(self._variance)) @property def params(self): return self._mean, self._variance def _concat(self, mean, variance): return tf.concat([mean, variance], axis=-1) def _split(self, params): return tf.split(params, 2, axis=-1) @property def _fisher_mean(self): return 1. / self._variance @property def _fisher_mean_factor(self): return 1. / tf.sqrt(self._variance) @property def _fisher_var(self): return 1. / (2 * tf.square(self._variance)) @property def _fisher_var_factor(self): return 1. / (tf.sqrt(2.) * self._variance) def multiply_fisher(self, vecs): mean_vec, var_vec = vecs return (self._fisher_mean * mean_vec, self._fisher_var * var_vec) def multiply_fisher_factor(self, vecs): mean_vec, var_vec = self._split(vecs) return (self._fisher_mean_factor * mean_vec, self._fisher_var_factor * var_vec) def multiply_fisher_factor_transpose(self, vecs): mean_vec, var_vec = vecs return self._concat(self._fisher_mean_factor * mean_vec, self._fisher_var_factor * var_vec) def multiply_fisher_factor_replicated_one_hot(self, index): assert len(index) == 1, "Length of index was {}".format(len(index)) index = index[0] if index < int(self._mean.shape[-1]): # Index corresponds to mean parameter. mean_slice = self._fisher_mean_factor[:, index] mean_slice = tf.expand_dims(mean_slice, axis=-1) mean_output = insert_slice_in_zeros(mean_slice, 1, int( self._mean.shape[1]), index) var_output = tf.zeros_like(mean_output) else: index -= int(self._mean.shape[-1]) # Index corresponds to variance parameter. var_slice = self._fisher_var_factor[:, index] var_slice = tf.expand_dims(var_slice, axis=-1) var_output = insert_slice_in_zeros(var_slice, 1, int(self._variance.shape[1]), index) mean_output = tf.zeros_like(var_output) return mean_output, var_output @property def fisher_factor_inner_shape(self): return tf.concat( [tf.shape(self._mean)[:-1], 2 * tf.shape(self._mean)[-1:]], axis=0) @property def fisher_factor_inner_static_shape(self): shape = self._mean.shape.as_list() return tf.TensorShape(shape[-1:] + [2 * shape[-1]]) def multiply_ggn(self, vector): raise NotImplementedError() def multiply_ggn_factor(self, vector): raise NotImplementedError() def multiply_ggn_factor_transpose(self, vector): raise NotImplementedError() def multiply_ggn_factor_replicated_one_hot(self, index): raise NotImplementedError() @property def ggn_factor_inner_shape(self): raise NotImplementedError() @property def ggn_factor_inner_static_shape(self): raise NotImplementedError() class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, NaturalParamsNegativeLogProbLoss): """Neg log prob loss for a categorical distribution parameterized by logits. Note that the Fisher (for a single case) of a categorical distribution, with respect to the natural parameters (i.e. the logits), is given by: F = diag(p) - p*p^T where p = softmax(logits). F can be factorized as F = B * B^T where B = diag(q) - p*q^T where q is the entry-wise square root of p. This is easy to verify using the fact that q^T*q = 1. 
""" def __init__(self, logits, targets=None, seed=None): """Instantiates a CategoricalLogitsNegativeLogProbLoss. Args: logits: Tensor of shape [batch_size, output_size]. Parameters for underlying distribution. targets: None or Tensor of shape [batch_size]. Each elements contains an index in [0, output_size). seed: int or None. Default random seed when sampling. """ self._logits = logits self._targets = targets super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed) @property def targets(self): return self._targets @property def dist(self): return tfp.distributions.Categorical(logits=self._logits) @property def _probs(self): return self.dist.probs_parameter() @property def _sqrt_probs(self): return tf.sqrt(self._probs) @property def params(self): return self._logits def multiply_fisher(self, vector): probs = self._probs return vector * probs - probs * tf.reduce_sum( vector * probs, axis=-1, keepdims=True) def multiply_fisher_factor(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - probs * tf.reduce_sum( sqrt_probs * vector, axis=-1, keepdims=True) def multiply_fisher_factor_transpose(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - sqrt_probs * tf.reduce_sum( probs * vector, axis=-1, keepdims=True) def multiply_fisher_factor_replicated_one_hot(self, index): assert len(index) == 1, "Length of index was {}".format(len(index)) probs = self._probs sqrt_probs = self._sqrt_probs sqrt_probs_slice = tf.expand_dims(sqrt_probs[:, index[0]], -1) padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1, int(sqrt_probs.shape[1]), index[0]) return padded_slice - probs * sqrt_probs_slice @property def fisher_factor_inner_shape(self): return tf.shape(self._logits) @property def fisher_factor_inner_static_shape(self): return self._logits.shape class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, NaturalParamsNegativeLogProbLoss): """Neg log prob loss for multiple Bernoulli distributions param'd by logits. Represents N independent Bernoulli distributions where N = len(logits). 
Its Fisher Information matrix is given by, F = diag(p * (1-p)) p = sigmoid(logits) As F is diagonal with positive entries, its factor B is, B = diag(sqrt(p * (1-p))) """ def __init__(self, logits, targets=None, seed=None): self._logits = logits self._targets = targets super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed) @property def targets(self): return self._targets @property def dist(self): return tfp.distributions.Bernoulli(logits=self._logits) @property def _probs(self): return self.dist.probs_parameter() @property def params(self): return self._logits def multiply_fisher(self, vector): return self._probs * (1 - self._probs) * vector def multiply_fisher_factor(self, vector): return tf.sqrt(self._probs * (1 - self._probs)) * vector def multiply_fisher_factor_transpose(self, vector): return self.multiply_fisher_factor(vector) # it's symmetric in this case def multiply_fisher_factor_replicated_one_hot(self, index): assert len(index) == 1, "Length of index was {}".format(len(index)) probs_slice = tf.expand_dims(self._probs[:, index[0]], -1) output_slice = tf.sqrt(probs_slice * (1 - probs_slice)) return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]), index[0]) @property def fisher_factor_inner_shape(self): return tf.shape(self._logits) @property def fisher_factor_inner_static_shape(self): return self._logits.shape def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): """Inserts slice into a larger tensor of zeros. Forms a new tensor which is the same shape as slice_to_insert, except that the dimension given by 'dim' is expanded to the size given by 'dim_size'. 'position' determines the position (index) at which to insert the slice within that dimension. Assumes slice_to_insert.shape[dim] = 1. Args: slice_to_insert: The slice to insert. dim: The dimension which to expand with zeros. dim_size: The new size of the 'dim' dimension. position: The position of 'slice_to_insert' in the new tensor. Returns: The new tensor. Raises: ValueError: If the slice's shape at the given dim is not 1. """ slice_shape = slice_to_insert.shape if slice_shape[dim] != 1: raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but " "was {}".format(dim, slice_to_insert.shape[dim])) before = [0] * int(len(slice_shape)) after = before[:] before[dim] = position after[dim] = dim_size - position - 1 return tf.pad(slice_to_insert, list(zip(before, after))) class OnehotCategoricalLogitsNegativeLogProbLoss( CategoricalLogitsNegativeLogProbLoss): """Neg log prob loss for a categorical distribution with onehot targets. Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying distribution is OneHotCategorical as opposed to Categorical. """ @property def dist(self): return tfp.distributions.OneHotCategorical(logits=self._logits) ================================================ FILE: kfac/python/ops/op_queue.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Helper for choosing which op to run next in a distributed setting.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import tensorflow.compat.v1 as tf class OpQueue(object): """Class for choosing which Op to run next. Constructs an infinitely repeating sequence of Ops in shuffled order. In K-FAC, this can be used to distribute inverse update operations among workers. """ def __init__(self, ops, seed=None): """Initializes an OpQueue. Args: ops: list of TensorFlow Ops. Ops to be selected from. All workers must initialize with the same set of ops. seed: int or None. Random seed used when shuffling order of ops. """ self._ops_by_name = {op.name: op for op in ops} # Construct a (shuffled) Dataset with Op names. op_names = tf.convert_to_tensor(list(sorted(op.name for op in ops))) op_names_dataset = ( tf.data.Dataset.from_tensor_slices(op_names).shuffle( len(ops), seed=seed).repeat()) self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next() @property def ops(self): """Ops this OpQueue can return in next_op().""" return self._ops_by_name.values() def next_op(self, sess): """Chooses which op to run next. Note: This call will make a call to sess.run(). Args: sess: tf.Session. Returns: Next Op chosen from 'ops'. """ # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii') # returns a str. next_op_name = sess.run(self._next_op_name).decode('ascii') return self._ops_by_name[next_op_name] ================================================ FILE: kfac/python/ops/optimizer.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """The KFAC optimizer.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import re import tensorflow.compat.v1 as tf from kfac.python.ops import curvature_matrix_vector_products as cmvp from kfac.python.ops import estimator as est from kfac.python.ops import fisher_factors as ff from kfac.python.ops import utils as utils ip = utils.ip ip_p = utils.ip_p sprod = utils.sprod sprod_p = utils.sprod_p # If True we the damping contribution is included in the quadratic model for # the purposes of computing qmodel_change in rho (the reduction ratio used in # the LM damping adjustment method). Note that the extra damping from the # "l2_reg" argument is always included. 
_INCLUDE_DAMPING_IN_QMODEL_CHANGE = False def set_global_constants(include_damping_in_qmodel_change=None): """Sets various global constants used by the classes in this module.""" global _INCLUDE_DAMPING_IN_QMODEL_CHANGE if include_damping_in_qmodel_change is not None: _INCLUDE_DAMPING_IN_QMODEL_CHANGE = include_damping_in_qmodel_change class KfacOptimizer(tf.train.GradientDescentOptimizer): """The KFAC Optimizer (https://arxiv.org/abs/1503.05671).""" def __init__(self, learning_rate, damping, layer_collection, cov_ema_decay=0.95, var_list=None, momentum=0.9, momentum_type="adam", use_weight_decay=False, weight_decay_coeff=0.1, qmodel_update_rescale=None, norm_constraint=None, name="KFAC", estimation_mode="gradients", colocate_gradients_with_ops=True, batch_size=None, placement_strategy=None, compute_params_stats=False, adapt_damping=False, update_damping_immediately=True, is_chief=True, prev_train_batch=None, loss=None, loss_fn=None, min_damping=1e-8, # this value is somewhat arbitrary l2_reg=0.0, damping_adaptation_decay=0.95, damping_adaptation_interval=5, damping_decrease_rho_threshold=0.75, damping_increase_rho_threshold=0.25, precon_damping_mult=1.0, use_passed_loss=True, train_batch=None, print_logs=False, tf_replicator=None, dtype="float32", **kwargs): """Initializes the K-FAC optimizer with the given settings. NOTE: this is a base class for K-FAC optimizers that offers full control over the execution of K-FAC's various ops. For a more fool-proof / automated version see for example PeriodicInvCovUpdateKfacOpt. Also, please keep in mind that while the K-FAC code loosely conforms to TensorFlow's Optimizer API it can't be used naively as a "drop in replacement" for basic classes like MomentumOptimizer. Using it properly with SyncReplicasOptimizer, for example, requires special care. When using it with Distribution Strategy, unlike common practice, K-FAC expects a loss tensor that is normalized by the per-replica batch size, and *not* by the total batch size (like you may see in TF Distribution Strategy tutorials). Regardless of whether you are using estimator, strategy, or a normal custom training loop, you should pass in the same loss. See the various examples in the "examples" directory for a guide about how to use K-FAC in various contexts and various systems, like TF-Estimator. See in particular the "convnet" example. google/examples also contains an example using TPUEstimator. Args: learning_rate: float or 0D Tensor. The base learning rate for the optimizer. Must be set to None if using one of the 'qmodel' momentum_type values. damping: float or 0D Tensor. This quantity times the identity matrix is (approximately) added to the curvature matrix (i.e. the Fisher or GGN) before it is inverted multiplied by the gradient when computing the (raw) update. This quantity should match the scale of the objective, so that if you put a multiplier on your loss you should apply the same multiplier to the damping. Roughly speaking, larger values constrain the update vector to a smaller region around zero, which we want to do when our local quadratic model is a less trustworthy local approximation of the true objective. The damping value is closely related to the trust region radius and to the classical Tikhonov regularization method. If the `adapt_damping` argument is True then this value is used only as an initial value for the adaptation method. layer_collection: The layer collection object, which holds the Fisher blocks, Kronecker factors, and losses associated with the graph. 
The layer_collection cannot be modified after KfacOptimizer's initialization. cov_ema_decay: The decay factor used when calculating the covariance estimate moving averages. (Default: 0.95) var_list: Optional list or tuple of variables to train. Defaults to tf.trainable_variables. momentum: The momentum decay constant to use. Only applies when momentum_type is 'regular' or 'adam'. (Default: 0.9) momentum_type: The type of momentum to use in this optimizer, one of 'regular', 'adam', 'qmodel', or 'qmodel_fixedmu'. 'regular' gives standard momentum. 'adam' gives a style of momentum reminisent of the Adam method, which seems to work better in practice. 'qmodel' makes the optimizer perform automatic control of both the learning rate and momentum using a quadratic model based method (see _compute_qmodel_hyperparams for more details). 'qmodel_fixedmu' is similar to 'qmodel' but only controls the learning rate. (Default: 'adam') use_weight_decay: If True, explicit "weight decay" is performed by K-FAC. Note that this is distinct from L2 regularization, and corresponds to optimizing a regularized version of the loss passed to minimize(), where the regularization term added is related to the "Fisher-Rao norm". See https://openreview.net/pdf?id=B1lz-3Rct7 for more details. Note that using this feature won't change the loss function you pass to minimize(), and thus the loss you report will not correspond precisely to what K-FAC is optimizing. (Default: False) weight_decay_coeff: The coefficient to use for weight decay (see above). (Default: 0.1) qmodel_update_rescale: float or None. An additional multiplier to apply to the update computed by the quadratic model based adjustment methods. If None it will behave like a value of 1.0. (Default: None) norm_constraint: float or Tensor. If specified, the update is scaled down so that its approximate squared Fisher norm v^T F v is at most the specified value. May only be used with momentum type 'regular'. See the docstring for the method _clip_updates() for a more detailed explanation of this feature. (Default: None) name: The name for this optimizer. (Default: 'KFAC') estimation_mode: The type of estimator to use for the Fishers/GGNs. Can be 'gradients', 'empirical', 'curvature_prop', 'curvature_prop_GGN', 'exact', or 'exact_GGN'. See the doc-string for FisherEstimator (in estimator.py) for more a more detailed description of these options. (Default: 'gradients'). colocate_gradients_with_ops: Whether we should request gradients we compute in the estimator be colocated with their respective ops. (Default: True) batch_size: The size of the mini-batch. Only needed when `momentum_type` == 'qmodel' or when `compute_params_stats` is True. Note that when using data parallelism where the model graph and optimizer are replicated across multiple devices, this should be the per-replica batch size. An example of this is sharded data on the TPU, where batch_size should be set to the total batch size divided by the number of shards. (Default: None) placement_strategy: string or None. Device placement strategy used when creating variables, and various ops. Can be None, 'round_robin', or 'replica_round_robin'. 'round_robin' supports round-robin placement of various ops on lists of provided devices. 'replica_round_robin' does something similar but over shards/replicas instead, and only works in certain 'replicated' contexts (e.g. TPUEstimator). 
The details of the different placement strategies are controlled by additional keyword arguments that can be passed to this class, and which are described in the different placement mixin classes in placement.py. (Default: None) compute_params_stats: Bool. If True, we compute the first order version of the statistics computed to estimate the Fisher/GGN. These correspond to the `variables` method in a one-to-one fashion. They are available via the `params_stats` property. When estimation_mode is 'empirical', this will correspond to the standard parameter gradient on the loss. (Default: False) adapt_damping: `Boolean`. If True we adapt the damping according to the Levenberg-Marquardt rule described in Section 6.5 of the original K-FAC paper. The details of this scheme are controlled by various additional arguments below. Also some of these arguments are extra pieces of information, such as the loss, needed by the method. Note that unless using a convenience subclass like PeriodicInvCovUpdateKfacOpt the damping adaptation op must be executed by the user (like the cov and inv ops). This op is returned by the maybe_pre_update_adapt_damping() method. (Default: False) update_damping_immediately: Damping adjustment strategy. If True then the damping is updated in the same optimizer minimize call as `(step+1) % damping_adaptation_interval == 0`, immediately after the parameter update is performed. If False then the damping is updated in the next step. If True then it is assumed that the apply_gradients op will safely update the model before returning; it is recommended to only resource variables in this case. (Default: True) is_chief: `Boolean`, `True` if the worker is chief. (Default: True) prev_train_batch: Training mini-batch used in the previous step. This will be used to evaluate loss by calling `loss_fn(prev_train_batch)` when damping adaptation is used. (Default: None) loss: `Tensor` the model loss, used as the pre-update loss in adaptive damping. Also used for the built-in log printing. When using Distribution Strategy, unlike common Distribution Strategy practice, this loss tensor should by normalized by the per-replica batch size and NOT the total batch size. (Default: None) loss_fn: `function` that takes as input training data tensor and returns a scalar loss. Only needed if using damping adaptation. When using Distribution Strategy, unlike common Distribution Strategy practice, the loss should by normalized by the per-replica batch size and NOT the total batch size. (Default: None) min_damping: `float`, Minimum value the damping parameter can take. Note that the default value of 1e-8 is quite arbitrary, and you may have to adjust this up or down for your particular problem. If you are using a non-zero value of l2_reg you *may* be able to set this to zero. (Default: 1e-8) l2_reg: `float` or 0D Tensor. Set this value to tell the optimizer what L2 regularization coefficient you are using (if any). Note the coefficient appears in the regularizer as coeff / 2 * sum(param**2), as the thing you multiply tf.nn.l2(param) by. This will be essentially added to the minimum damping, but also included in the qmodel change computations (used for adjusting the damping) even when _INCLUDE_DAMPING_IN_QMODEL_CHANGE is False. Note that the user is still responsible for adding regularization to the loss. (Default: 0.0) damping_adaptation_decay: `float`. The `damping` parameter is multiplied by the `damping_adaptation_decay` every `damping_adaptation_interval` number of iterations. 
(Default: 0.99) damping_adaptation_interval: `int`. Number of steps in between updating the `damping` parameter. Note that damping is adapted at the step where (step+1) % damping_adaptation_interval == 0, (or immediately at the start of the next step by maybe_pre_update_adapt_damping() if update_damping_immediately is False). (Default: 5) damping_decrease_rho_threshold: `int`. The threshold for rho above which we decrease the damping when using automatic damping adaptation. (Default: 0.75) damping_increase_rho_threshold: `int`. The threshold for rho below which we increase the damping when using automatic damping adaptation. (Default: 0.25) precon_damping_mult: `float`. A multiplier used on the damping value passed to the preconditioner (vs the quadratic model when using momentum_type 'qmodel'). (Default: 1.0) use_passed_loss: `Boolean`. If True we use the loss tensor passed in by the user (via minimze() or compute_gradients() or the set_loss() method) in damping adaptation scheme, instead of calling loss_fn() a second time for this. This is more efficient but may not always be desired. (Default: True) train_batch: Training mini-batch used in the current step. This will be used to evaluate loss by calling `loss_fn(train_batch)` when damping adaptation is used and `use_passed_loss` is False. (Default: None) print_logs: `Boolean`. If True, we print some logging info using tf.print after each iteration. This is done in the method _maybe_print_logging_info, which we encourage you to modify in order to add whatever you want. (Default: False) tf_replicator: A Replicator object or None. If not None, K-FAC will set itself up to work inside of the provided TF-Replicator object. (Default: None) dtype: TF dtype or string representing one. dtype used for scalar properties (rho, etc). (Default: "float32") **kwargs: Arguments to be passed to specific placement strategy mixin. Check `placement.RoundRobinPlacementMixin` for example. Raises: ValueError: If the momentum type is unsupported. ValueError: If clipping is used with momentum type other than 'regular'. ValueError: If no losses have been registered with layer_collection. ValueError: If momentum is non-zero and momentum_type is not 'regular' or 'adam'. """ dtype = tf.dtypes.as_dtype(dtype) self._dtype = dtype self._layers = layer_collection self._colocate_gradients_with_ops = colocate_gradients_with_ops momentum_type = momentum_type.lower() legal_momentum_types = ["regular", "adam", "qmodel", "qmodel_fixedmu"] if momentum_type not in legal_momentum_types: raise ValueError("Unsupported momentum type {}. Must be one of {}." 
.format(momentum_type, legal_momentum_types)) if momentum_type not in ["regular", "adam"] and norm_constraint is not None: raise ValueError("Update clipping is only supported with momentum " "type 'regular' and 'adam'.") if momentum_type == "qmodel" and momentum is not None: raise ValueError("Momentum must be None if using a momentum_type " "'qmodel'.") self._momentum_type = momentum_type self._momentum = momentum self._use_weight_decay = use_weight_decay self._weight_decay_coeff = weight_decay_coeff self._norm_constraint = norm_constraint self._batch_size = batch_size self._placement_strategy = placement_strategy # Damping adaptation parameters self._adapt_damping = adapt_damping if self._adapt_damping: with tf.variable_scope(name): self._damping = tf.get_variable( "damping", initializer=lambda: tf.constant(damping, dtype=dtype), trainable=False, use_resource=True, dtype=dtype) else: self._damping = damping self._update_damping_immediately = update_damping_immediately self._is_chief = is_chief self._prev_train_batch = prev_train_batch self._loss_tensor = loss self._loss_fn = loss_fn self._damping_adaptation_decay = damping_adaptation_decay self._damping_adaptation_interval = damping_adaptation_interval self._omega = ( self._damping_adaptation_decay**self._damping_adaptation_interval) self._min_damping = min_damping self._use_passed_loss = use_passed_loss if not use_passed_loss and train_batch is None: raise ValueError("Must pass in train_batch if used_passed_loss is false.") self._damping_decrease_rho_threshold = damping_decrease_rho_threshold self._damping_increase_rho_threshold = damping_increase_rho_threshold self._train_batch = train_batch self._print_logs = print_logs self._l2_reg = l2_reg self._precon_damping_mult = precon_damping_mult if self._momentum_type.startswith("qmodel"): if learning_rate is not None: raise ValueError("'learning_rate' must be set to None if using one of " "the 'qmodel' momentum types.") if qmodel_update_rescale is not None: learning_rate = qmodel_update_rescale else: learning_rate = 1.0 else: if learning_rate is None: raise ValueError("'learning_rate' must *not* be set to None unless " "using one of the 'qmodel' momentum types.") if qmodel_update_rescale is not None: raise ValueError("'qmodel_update_rescale' must be None unless using " "one of the 'qmodel' momentum types.") self._qmodel_update_rescale = qmodel_update_rescale with tf.variable_scope(name): nan_init = lambda: tf.constant(float("nan"), dtype=dtype) # We store rho only for possible logging purposes. 
self._rho = tf.get_variable( "rho", initializer=nan_init, dtype=dtype, trainable=False, use_resource=True) self._prev_loss = tf.get_variable( "prev_loss", initializer=nan_init, dtype=dtype, trainable=False, use_resource=True) self._qmodel_learning_rate = tf.get_variable( "qmodel_learning_rate", initializer=nan_init, dtype=dtype, trainable=False, use_resource=True) self._qmodel_momentum = tf.get_variable( "qmodel_momentum", initializer=nan_init, dtype=dtype, trainable=False, use_resource=True) self._qmodel_change = tf.get_variable( "qmodel_change", initializer=nan_init, dtype=dtype, trainable=False, use_resource=True) self._counter = tf.get_variable( "counter", dtype=tf.int64, shape=(), trainable=False, initializer=tf.zeros_initializer, use_resource=True, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA) variables = var_list or tf.trainable_variables() if tf_replicator is not None or tf.distribute.has_strategy(): def _get_sanitized_name(var_name): return re.sub(r"replica_\d+_", "", var_name) # This tells K-FAC's libraries that we are using TF-Replicator with this # particular Replicator object. utils.set_global_constants(tf_replicator=tf_replicator) # We need to sanitize the names of the variables that K-FAC creates # so they are the same between replicas. ff.set_global_constants(get_sanitized_name_fn=_get_sanitized_name) self._fisher_est = est.make_fisher_estimator( placement_strategy=placement_strategy, variables=variables, cov_ema_decay=cov_ema_decay, damping=self._damping * self._precon_damping_mult, layer_collection=self.layers, exps=(-1,), estimation_mode=estimation_mode, colocate_gradients_with_ops=self._colocate_gradients_with_ops, compute_params_stats=compute_params_stats, batch_size=batch_size, **kwargs) super(KfacOptimizer, self).__init__(learning_rate, name=name) def get_cov_vars(self): """Returns all covaraiance varaiables.""" return self._fisher_est.get_cov_vars() def get_inv_vars(self): """Returns all inverse computation related varaiables.""" return self._fisher_est.get_inv_vars() @property def factors(self): return self._fisher_est.factors @property def registered_variables(self): return self._fisher_est.variables @property def layers(self): return self._layers @property def mat_type(self): return self._fisher_est.mat_type @property def damping(self): if self._adapt_damping: return tf.identity(self._damping) else: return tf.convert_to_tensor(self._damping) @property def damping_adaptation_interval(self): return self._damping_adaptation_interval @property def learning_rate(self): if self._momentum_type.startswith("qmodel"): return self._learning_rate * tf.identity(self._qmodel_learning_rate) else: return tf.convert_to_tensor(self._learning_rate) @property def momentum(self): if self._momentum_type.startswith("qmodel"): return tf.identity(self._qmodel_momentum) else: return tf.convert_to_tensor(self._momentum) @property def rho(self): return tf.identity(self._rho) @property def qmodel_change(self): return tf.identity(self._qmodel_change) @property def counter(self): return tf.identity(self._counter) @property def params_stats(self): return self._fisher_est.params_stats def set_loss(self, loss): # Use this method if you have overridden both the minimize method and # compute_gradients method but still want K-FAC to know the loss value # (which is required for damping adaptation). 
self._loss_tensor = loss def _maybe_print_logging_info(self): if not self._print_logs: return tf.no_op() p = [] p.append(("=========================================================",)) p.append(("Iteration:", self.counter)) p.append(("mini-batch loss =", self._loss_tensor)) p.append(("learning_rate =", self.learning_rate, "| momentum =", self.momentum)) p.append(("damping =", self.damping, "| rho =", self.rho, "| qmodel_change =", self.qmodel_change)) p.append(("=========================================================",)) return utils.multiline_print(p) def make_vars_and_create_op_thunks(self): """Make vars and create op thunks. Returns: cov_update_thunks: List of cov update thunks. Corresponds one-to-one with the list of factors given by the "factors" property. inv_update_thunks: List of inv update thunks. Corresponds one-to-one with the list of factors given by the "factors" property. """ scope = self.get_name() + "/" + self._fisher_est.name return self._fisher_est.make_vars_and_create_op_thunks(scope=scope) def create_ops_and_vars_thunks(self): """Create thunks that make the ops and vars on demand. This function returns 4 lists of thunks: cov_variable_thunks, cov_update_thunks, inv_variable_thunks, and inv_update_thunks. The length of each list is the number of factors and the i-th element of each list corresponds to the i-th factor (given by the "factors" property). Note that the execution of these thunks must happen in a certain partial order. The i-th element of cov_variable_thunks must execute before the i-th element of cov_update_thunks (and also the i-th element of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks must execute before the i-th element of inv_update_thunks. TL;DR (oversimplified): Execute the thunks according to the order that they are returned. Returns: cov_variable_thunks: A list of thunks that make the cov variables. cov_update_thunks: A list of thunks that make the cov update ops. inv_variable_thunks: A list of thunks that make the inv variables. inv_update_thunks: A list of thunks that make the inv update ops. """ scope = self.get_name() + "/" + self._fisher_est.name return self._fisher_est.create_ops_and_vars_thunks(scope=scope) def check_var_list(self, var_list): if set(var_list) != set(self.registered_variables): raise ValueError("var_list doesn't match with set of Fisher-estimating " "variables (i.e. those that were registered).") @staticmethod def _scale_loss(loss_value): # tf.compat.v1.train.Optimizer uses this method to account for the Estimator # + Distribution Strategy (DS) case. DS wants a scaled loss and to aggregate # gradients via a sum. Estimator provides an unscaled loss by default. So, # this method would divide the loss by num_replicas. For our optimizer, we # require users to pass in an unscaled loss, so we do not want this method # to alter Estimator's input when it's used with DS. return loss_value def minimize(self, loss, global_step=None, var_list=None, gate_gradients=tf.train.Optimizer.GATE_OP, aggregation_method=None, colocate_gradients_with_ops=True, name=None, grad_loss=None, **kwargs): # This method has the same general arguments as the minimize methods in # standard optimizers do. # With most optimizers used with Distribution Strategy (DS), the user is # expected to scale their loss by 1.0 / global_batch_size, then DS # aggregates the gradients via a sum. We expect users to pass in a loss that # is normalized by the per-replica batch size only. 
This is so we can # handle the Estimator and DS cases in a consistent way. As a side effect, # this means each replica must have the same per-replica batch size. if var_list is None: var_list = self.registered_variables else: self.check_var_list(var_list) return super(KfacOptimizer, self).minimize( loss, global_step=global_step, var_list=var_list, gate_gradients=gate_gradients, aggregation_method=aggregation_method, colocate_gradients_with_ops=colocate_gradients_with_ops, name=name, grad_loss=grad_loss, **kwargs) def compute_gradients(self, loss, var_list=None, gate_gradients=tf.train.Optimizer.GATE_OP, aggregation_method=None, colocate_gradients_with_ops=True, grad_loss=None, **kwargs): # This method has the same general arguments as the minimize methods in # standard optimizers do. Unlike the compute_gradient method for typical # optimizer implementations, this one performs cross-replica syncronization # automatically when under one the supported replicated contexts, and so # use of things like CrossShardOptimizer is unessesary (and wasteful). if var_list is not None: self.check_var_list(var_list) grads_and_vars = super(KfacOptimizer, self).compute_gradients( loss=loss, var_list=var_list, gate_gradients=gate_gradients, aggregation_method=aggregation_method, colocate_gradients_with_ops=colocate_gradients_with_ops, grad_loss=grad_loss, **kwargs) # When using the TF Keras fused BatchNormalization implementation, in some # cases the gradient shape is ?. KFAC needs the gradient shape in at least # two cases: when registering a layer as generic, and when computing the # qmodel. The gradient should have the same shape as the variable, so when # any dimension is None we set the shape ourselves. for grad, var in grads_and_vars: if len(grad.shape) and not all(grad.shape.as_list()): grad.set_shape(var.shape) grads, vars_ = list(zip(*grads_and_vars)) grads = utils.all_average(grads) return tuple(zip(grads, vars_)) def _is_damping_adaptation_time(self): # Note that we update damping at the step right before the end of the # interval, instead of at the beginning of the next interval. This is # so it properly lines up with the periodic inverse updates (i.e. happens # immediately before them.) return tf.equal(tf.mod(self.counter + 1, self._damping_adaptation_interval), 0) def _is_just_after_damping_adaptation_time(self): is_just_after = tf.equal( tf.mod(self.counter, self._damping_adaptation_interval), 0) return tf.logical_and(is_just_after, tf.not_equal(self.counter, 0)) def _maybe_update_prev_loss(self): if self._adapt_damping: should_update_prev_loss = self._is_damping_adaptation_time() def update_prev_loss(): loss = self._loss_tensor if self._use_passed_loss else self._loss_fn( self._train_batch) loss = utils.all_average(loss) return tf.group(utils.smart_assign(self._prev_loss, loss, force_cast=True)) maybe_update_prev_loss_op = tf.cond( should_update_prev_loss, update_prev_loss, tf.no_op) return maybe_update_prev_loss_op else: return tf.no_op() def maybe_pre_update_adapt_damping(self): """Maybe adapt the damping according to the built-in scheme. Unless using a convenience class like PeriodicInvCovUpdateKfacOpt the op returned by this function should be run every sess.run call, preferably before the inv ops (using a control dependency). Returns: An op that applies the specified gradients, and also updates the counter variable. 
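A rough usage sketch (illustrative only; here `opt` stands for a KfacOptimizer
instance and `inv_update_thunks` for the inv update thunks returned by
make_vars_and_create_op_thunks()):

  # Ensure the (possible) damping update runs before the inv updates.
  damping_op = opt.maybe_pre_update_adapt_damping()
  with tf.control_dependencies([damping_op]):
    inv_update_op = tf.group(*(thunk() for thunk in inv_update_thunks))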
""" if (not self._adapt_damping or not self._is_chief or self._update_damping_immediately): return tf.no_op() # We update the damping on the iteration that is technically after # where we compute qmodel_change. However, it should happen before # anything else does, so it's as if we computed it on the previous # iteration. The only reason we do it this way and not on the # actual iteration is due to weirdness related to parameter servers # or possibly just non-resource variables. Essentially, the model # variables won't be updated and so we can't properly compute # prev_batch_loss until the next sess.run() call. should_update_damping = self._is_just_after_damping_adaptation_time() maybe_update_damping = tf.cond( should_update_damping, self._update_damping, tf.no_op) return maybe_update_damping def _maybe_post_update_adapt_damping(self): if not self._update_damping_immediately or not self._adapt_damping: return tf.no_op() should_update_damping = self._is_damping_adaptation_time() maybe_update_damping = tf.cond( should_update_damping, self._update_damping, tf.no_op) return maybe_update_damping def apply_gradients(self, grads_and_vars, *args, **kwargs): """Apply updates to variables. Args: grads_and_vars: List of (gradient, variable) pairs. *args: Additional arguments for super.apply_gradients. **kwargs: Additional keyword arguments for super.apply_gradients. Returns: An op that applies the specified gradients, and also updates the counter variable. """ maybe_update_prev_loss = self._maybe_update_prev_loss() with tf.control_dependencies([maybe_update_prev_loss]): # In Python 3, grads_and_vars can be a zip() object which can only be # iterated over once. By converting it to a list, we ensure that it can be # iterated over more than once. grads_and_vars = list(grads_and_vars) with tf.variable_scope(self.get_name()): # Compute raw update step (self._learning_rate not yet applied). # Note that this function also updates the velocity vectors. raw_updates_and_vars = self._compute_raw_update_steps(grads_and_vars) if self._use_weight_decay: raw_updates_and_vars = self._add_weight_decay(raw_updates_and_vars) if tf.distribute.has_strategy(): # Distribution Strategy (DS) expects users to pass in loss / # global_batch_size to minimize. We require users not to do this, so our # code can consistently deal with input in the single device, Estimator, # and DS cases. However, the _distributed_apply call in # super(...).apply_gradients(...) will perform a sum over replicas to # aggregate the gradients. Therefore, we divide by the number of # replicas so that the scaling of the update applied to the variables # is correct. num_replicas = tf.distribute.get_strategy().num_replicas_in_sync raw_updates_and_vars = [(update/num_replicas, var) for update, var in raw_updates_and_vars] # Update trainable variables with this step, applying self._learning_rate. apply_op = super(KfacOptimizer, self).apply_gradients( raw_updates_and_vars, *args, **kwargs) with tf.control_dependencies([apply_op]): maybe_post_update_damping_op = self._maybe_post_update_adapt_damping() with tf.control_dependencies([maybe_post_update_damping_op]): maybe_print_logging_info = self._maybe_print_logging_info() with tf.control_dependencies([maybe_print_logging_info]): # Update the main counter return tf.group( utils.smart_assign(self._counter, 1, assign_fn=tf.assign_add)) def _add_weight_decay(self, grads_and_vars): """Applies weight decay. Args: grads_and_vars: List of (gradient, variable) pairs. Returns: List of (gradient, variable) pairs. 
""" return [(grad + self._weight_decay_coeff * tf.stop_gradient(var), var) for grad, var in grads_and_vars] def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars): """Computes the squared (approximate) Fisher norm of the updates. This is defined as v^T F v, where F is the approximate Fisher matrix as computed by the estimator, and v = F^{-1} g, where g is the gradient. This is computed efficiently as v^T g. Args: grads_and_vars: List of (gradient, variable) pairs. precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. Must be the result of calling `self._multiply_preconditioner` on `grads_and_vars`. Returns: Scalar representing the squared norm. Raises: ValueError: if the two list arguments do not contain the same variables, in the same order. """ return ip_p(grads_and_vars, precon_grads_and_vars) def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars): """Computes the scale factor for the update to satisfy the norm constraint. Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint, F is the approximate Fisher matrix, and r is the update vector, i.e. -alpha * v, where alpha is the learning rate, and v is the preconditioned gradient. This is based on Section 5 of Ba et al., Distributed Second-Order Optimization using Kronecker-Factored Approximations. Note that they absorb the learning rate alpha (which they denote eta_max) into the formula for the coefficient, while in our implementation, the rescaling is done before multiplying by alpha. Hence, our formula differs from theirs by a factor of alpha. Args: grads_and_vars: List of (gradient, variable) pairs. precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. Must be the result of calling `self._multiply_preconditioner` on `grads_and_vars`. Returns: Scalar representing the coefficient which should be applied to the preconditioned gradients to satisfy the norm constraint. """ sq_norm_grad = self._squared_fisher_norm(grads_and_vars, precon_grads_and_vars) sq_norm_up = sq_norm_grad * self._learning_rate**2 return tf.minimum( tf.ones(shape=(), dtype=sq_norm_up.dtype), tf.sqrt(self._norm_constraint / sq_norm_up)) def _clip_updates(self, grads_and_vars, precon_grads_and_vars): """Rescales the preconditioned gradients to satisfy the norm constraint. Rescales the preconditioned gradients such that the resulting update r (after multiplying by the learning rate) will satisfy the norm constraint. This constraint is that r^T F r <= C, where F is the approximate Fisher matrix, and C is the norm_constraint attribute. See Section 5 of Ba et al., Distributed Second-Order Optimization using Kronecker-Factored Approximations. Args: grads_and_vars: List of (gradient, variable) pairs. precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. Must be the result of calling `self._multiply_preconditioner` on `grads_and_vars`. Returns: List of (rescaled preconditioned gradient, variable) pairs. """ coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars) return sprod_p(coeff, precon_grads_and_vars) def _compute_prev_updates(self, variables): """Returns the previous update vector computed using the quadratic model. Note that this vector does not include any additional scaling that may have been applied after the quadratic model optimization (i.e. the quantity returned by self.learning_rate). Note that this may not actually be the previous update if momentum_type="adam". Args: variables: List of variables for which to compute the previous update. 
Returns: List of (previous_update, variable) pairs in the same order as `variables`. """ # What guarantee do we have that this is the old value and not the # new value? Remember that control flow doesn't work in TF whenever # non-resource variables are involved. # TODO(b/121245468): Figure out if this is a problem and if not explain why # Or fix it by somehow forcing the slots to use resource variables instead. prev_updates = sprod( -1., tuple(self._zeros_slot(var, "velocity", self.get_name()) for var in variables)) return tuple(zip(prev_updates, variables)) def _compute_qmodel(self, raw_updates_and_vars, prev_updates_and_vars, grads_and_vars, should_average_over_replicas=True): """Computes the 2 dimensional version of the (exact) quadratic model. The two dimensions are the update and the previous update vectors. The arguments are all lists of (Tensor, Variable) pairs where the variables are the same and in the same order. Args: raw_updates_and_vars: a list of (precond grad, variable) pairs. Raw update proposal to apply to the variables (before scaling by learning rate and addition of velocity/momentum). prev_updates_and_vars: a list of (previous update, variable) pairs. Previous update applied to the variables (includes learning rate and velocity/momentum). grads_and_vars: a list of (gradient, variable) pairs. Gradients for the parameters and the variables that the updates are being applied to. The order of this list must correspond to the order of the other arguments. (Note that this function doesn't actually apply the update.) should_average_over_replicas: a bool. If true, results will be averaged over replicas (using utils.all_average). (Default: True) Returns: m, c, and b. m is the 2 by 2 matrix representing the quadratic term, c is a 2 by 1 vector representing the linear term, and b is the 2 by 2 matrix representing only the contribution of the damping to the quadratic term. These are all multi-dimensional lists (lists of lists) of Tensors. """ # Raw update proposal to apply to the variables (before scaling by learning # rate and addition of velocity/momentum).
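    # Notation for the computation below (a summary, not new behavior): with
    # u = raw update, p = previous update, g = gradient, and v_1 = u, v_2 = p,
    #   b[i][j] = damping * <v_i, v_j>
    #   m[i][j] ~= v_i^T (F + damping * I) v_j, computed via the
    #              factor-transpose products <B^T v_i, B^T v_j> / batch_size
    #              (the exact mini-batch Fisher/GGN, not the Kronecker-factored
    #              approximation)
    #   c[i]    = <g, v_i>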
raw_updates, _ = zip(*raw_updates_and_vars) prev_updates, _ = zip(*prev_updates_and_vars) grads, variables = zip(*grads_and_vars) utils.assert_variables_match_pairs_list( raw_updates_and_vars, prev_updates_and_vars, error_message="_compute_qmodel raw_updates_and_vars and " "prev_updates_and_vars differ.") utils.assert_variables_match_pairs_list( prev_updates_and_vars, grads_and_vars, error_message="_compute_qmodel prev_updates_and_vars and " "grads_and_vars differ.") cmvpc = cmvp.CurvatureMatrixVectorProductComputer( self.layers, variables, colocate_gradients_with_ops=self._colocate_gradients_with_ops) # Compute the matrix-vector products with the transposed Fisher factor # (or GGN factor) if self.mat_type == "Fisher": mft_updates = cmvpc.multiply_fisher_factor_transpose(raw_updates) mft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates) elif self.mat_type == "GGN" or self.mat_type == "Empirical_Fisher": mft_updates = cmvpc.multiply_ggn_factor_transpose(raw_updates) mft_prev_updates = cmvpc.multiply_ggn_factor_transpose(prev_updates) batch_size = tf.cast(self._batch_size, dtype=mft_updates[0].dtype) damping = tf.cast(self.damping, dtype=raw_updates[0].dtype) b_11 = damping * ip(raw_updates, raw_updates) b_21 = damping * ip(prev_updates, raw_updates) b_22 = damping * ip(prev_updates, prev_updates) b = [[b_11, b_21], [b_21, b_22]] # Compute the entries of the 2x2 matrix m_11 = ip(mft_updates, mft_updates) / batch_size m_21 = ip(mft_prev_updates, mft_updates) / batch_size m_22 = (ip(mft_prev_updates, mft_prev_updates) / batch_size) m = [[m_11 + b_11, m_21 + b_21], [m_21 + b_21, m_22 + b_22]] if should_average_over_replicas: m = utils.all_average(m) c_1 = ip(grads, raw_updates) c_2 = ip(grads, prev_updates) c = [[c_1], [c_2]] return m, c, b @property def _sub_damping_out_qmodel_change_coeff(self): return 1.0 - self._l2_reg / self.damping def _compute_qmodel_hyperparams(self, m, c, b, fixed_mu=None): """Compute optimal update hyperparameters from the quadratic model. More specifically, if L is the loss we minimize a quadratic approximation of L(theta + d) which we denote by qmodel(d) with d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where qmodel(d) = (1/2) * d^T * C * d + grad^T*d + L(theta) . Unlike in the KL clipping approach we use the non-approximated quadratic model where the curvature matrix C is the true Fisher (or GGN) on the current mini-batch (computed without any approximations beyond mini-batch sampling), with the usual Tikhonov damping/regularization applied, C = F + damping * I See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of the formula. See Appendix C for a discussion of the trick of using a factorized Fisher matrix to more efficiently compute the required vector-matrix-vector products. Args: m: 2 by 2 matrix representing the quadratic term (a list of list of 0D Tensors) c: a 2 by 1 vector representing the linear term (a list of 0D Tensors) b: 2 by 2 matrix representing only the contribution of the damping to the quadratic term fixed_mu: A fixed value of mu to use instead of the optimal one. (Default: None) Returns: (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the quadratic model, and qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0) = qmodel(alpha*precon_grad + mu*prev_update) - L(theta). """ def non_zero_prevupd_case(): r"""Computes optimal (alpha, mu) given non-zero previous update. We solve the full 2x2 linear system. 
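      Concretely, when no fixed mu is supplied the solution is
      [alpha; mu] = -m^{-1} c, and at this optimum the predicted change in
      the quadratic model simplifies to 0.5 * [alpha; mu]^T c.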
See Martens & Grosse (2015), Section 7, definition of $\alpha^*$ and $\mu^*$. Returns: (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the quadratic model, and qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0). """ if fixed_mu is None: sol = -1. * _two_by_two_solve(m, c) alpha = sol[0, 0] mu = sol[1, 0] if self._qmodel_update_rescale is None: # This is a special formula that takes advantage of the particular # relationship of sol to m and c. It should be equivalent to # _eval_quadratic(m, c, sol) if everything is working properly. qmodel_change = 0.5 * tf.reduce_sum(sol * c) else: sol = self._qmodel_update_rescale * sol qmodel_change = _eval_quadratic(m, c, sol) # Subtract out the damping-related penalty if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE: qmodel_change -= (self._sub_damping_out_qmodel_change_coeff * _eval_quadratic_no_c(b, sol)) else: alpha = -1. * (fixed_mu * m[0][1] + c[0][0]) / (m[0][0]) mu = fixed_mu sol = [[alpha], [mu]] if self._qmodel_update_rescale is not None: sol = self._qmodel_update_rescale * tf.convert_to_tensor(sol) qmodel_change = _eval_quadratic(m, c, sol) # Subtract out the damping-related penalty if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE: qmodel_change -= (self._sub_damping_out_qmodel_change_coeff * _eval_quadratic_no_c(b, sol)) return tf.squeeze(alpha), tf.squeeze(mu), tf.squeeze(qmodel_change) def zero_prevupd_case(): r"""Computes optimal (alpha, mu) given all-zero previous update. The linear system reduces to 1x1. See Martens & Grosse (2015), Section 6.4, definition of $\alpha^*$. Returns: (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the quadratic model, and qmodel_change = qmodel(alpha*precon_grad) - qmodel(0) """ alpha = -c[0][0] / m[0][0] if fixed_mu is None: mu = 0.0 else: mu = fixed_mu mu = tf.cast(mu, dtype=alpha.dtype) if self._qmodel_update_rescale is None: # This is a special formula that takes advantage of the particular # relationship of sol to m and c. qmodel_change = 0.5 * alpha * c[0][0] # Subtract out the damping-related penalty if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE: qmodel_change -= (self._sub_damping_out_qmodel_change_coeff * 0.5 * tf.square(alpha) * b[0][0]) else: sol = self._qmodel_update_rescale * alpha qmodel_change = 0.5 * m[0][0] * tf.square(sol) + c[0][0] * sol # Subtract out the damping-related penalty if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE: qmodel_change -= (self._sub_damping_out_qmodel_change_coeff * 0.5 * tf.square(sol) * b[0][0]) return alpha, mu, qmodel_change return tf.cond( tf.equal(c[1][0], 0.0), zero_prevupd_case, non_zero_prevupd_case) def _compute_approx_qmodel_change(self, updates_and_vars, grads_and_vars): """Computes the change in the approximate quadratic model. 'Approximate' means the quadratic model which uses the approximate Fisher/GGN as the curvature matrix, instead of the exact Fisher/GGN which is used by _compute_qmodel and its dependent methods. Args: updates_and_vars: List of (update, variable) pairs. grads_and_vars: List of (gradient, variable) pairs. Returns: A 0D Tensor which is the change in the approximate quadratic model. """ quad_term = 0.5*ip_p(updates_and_vars, self._fisher_est.multiply(updates_and_vars)) if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE: # This isn't quite right, but doing it properly is too awkward. 
quad_term -= (self._sub_damping_out_qmodel_change_coeff * 0.5 * self.damping*ip_p(updates_and_vars, updates_and_vars)) linear_term = ip_p(updates_and_vars, grads_and_vars) return quad_term + linear_term def _maybe_update_qmodel_change(self, qmodel_change_thunk): """Returns an op which updates the qmodel_change variable if it is time to. Args: qmodel_change_thunk: A callable which when evaluated returns the qmodel change. Returns: An op. """ def update_qmodel_change(): # The tf.group is needed to strip away the value so it can be used # in the cond later. return tf.group(utils.smart_assign(self._qmodel_change, tf.squeeze(qmodel_change_thunk()), force_cast=True)) # Note that we compute the qmodel change and store it in a variable so # it can be used at the next sess.run call (where rho will actually be # computed). return tf.cond(self._is_damping_adaptation_time(), update_qmodel_change, tf.no_op) def _multiply_preconditioner(self, vecs_and_vars): return self._fisher_est.multiply_inverse(vecs_and_vars) def _get_qmodel_quantities(self, grads_and_vars): # Compute "preconditioned gradient". precon_grads_and_vars = self._multiply_preconditioner(grads_and_vars) var_list = tuple(var for (_, var) in grads_and_vars) prev_updates_and_vars = self._compute_prev_updates(var_list) # While it might seem like this call performs needless computations # involving prev_updates_and_vars in the case where it is zero, because # we extract out only the part of the solution that is not zero the rest # of it will not actually be computed by TensorFlow (I think). m, c, b = self._compute_qmodel( precon_grads_and_vars, prev_updates_and_vars, grads_and_vars) return precon_grads_and_vars, m, c, b def _compute_raw_update_steps(self, grads_and_vars): """Computes the raw update steps for the variables given the gradients. Note that these "raw updates" are further multiplied by -1*self._learning_rate when the update is eventually applied in the superclass (which is GradientDescentOptimizer). Args: grads_and_vars: List of (gradient, variable) pairs. Returns: A list of tuples (raw_update, var) where raw_update is the update to the parameter. These updates must be actually used since they carry with them certain control dependencies that need to happen. """ if self._momentum_type == "regular": # Compute "preconditioned" gradient. precon_grads_and_vars = self._multiply_preconditioner(grads_and_vars) # Apply "KL clipping" if asked for. if self._norm_constraint is not None: precon_grads_and_vars = self._clip_updates(grads_and_vars, precon_grads_and_vars) # Update the velocities and get their values as the "raw" updates raw_updates_and_vars = self._update_velocities(precon_grads_and_vars, self._momentum) if self._adapt_damping and self._is_chief: def compute_qmodel_change(): updates_and_vars = sprod_p(-1. * self._learning_rate, raw_updates_and_vars) return self._compute_approx_qmodel_change(updates_and_vars, grads_and_vars) maybe_update_qmodel_change = self._maybe_update_qmodel_change( compute_qmodel_change) with tf.control_dependencies([maybe_update_qmodel_change]): # Making this a tuple is important so that it actually gets evaluated # in the context. return tuple((tf.identity(vec), var) for (vec, var) in raw_updates_and_vars) else: return raw_updates_and_vars elif self._momentum_type == "adam": velocities_and_vars = self._update_velocities(grads_and_vars, self._momentum) # The "preconditioned" velocity vector is the raw update step. 
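      # Note the ordering here: for momentum_type="adam" the velocity
      # accumulates the raw gradients and the preconditioner is then applied
      # to the accumulated velocity (update ~= F^{-1} of the decayed gradient
      # sum), whereas for momentum_type="regular" above, gradients are
      # preconditioned first and the velocity accumulates the preconditioned
      # steps.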
raw_updates_and_vars = self._multiply_preconditioner(velocities_and_vars) # Apply "KL clipping" if asked for. Note that we are applying this to # the combined preconditioned gradient + velocity, unlike for the # momentum_type = 'regular' case. if self._norm_constraint is not None: raw_updates_and_vars = self._clip_updates(velocities_and_vars, raw_updates_and_vars) if self._adapt_damping and self._is_chief: def compute_qmodel_change(): # This is a special formula that exploits the structure of the # particular update we are using. Note that this is using the approx # Fisher as defined by the inverses, which might be stale (perhaps so # stale that they are using an old damping value, which may mess up # the damping adaptation method). return (0.5 * (self._learning_rate**2) * ip_p(raw_updates_and_vars, velocities_and_vars) - self._learning_rate * ip_p(raw_updates_and_vars, grads_and_vars)) maybe_update_qmodel_change = self._maybe_update_qmodel_change( compute_qmodel_change) with tf.control_dependencies([maybe_update_qmodel_change]): # Making this a tuple is important so that it actually gets evaluated # in the context. return tuple((tf.identity(vec), var) for (vec, var) in raw_updates_and_vars) else: return raw_updates_and_vars elif (self._momentum_type == "qmodel" or self._momentum_type == "qmodel_fixedmu"): precon_grads_and_vars, m, c, b = self._get_qmodel_quantities( grads_and_vars) if self._momentum_type == "qmodel_fixedmu": fixed_mu = self._momentum else: fixed_mu = None # Compute optimal velocity update parameters according to quadratic # model alpha, mu, qmodel_change = self._compute_qmodel_hyperparams( m, c, b, fixed_mu=fixed_mu) qmodel_assign_op = tf.group( utils.smart_assign(self._qmodel_change, qmodel_change, force_cast=True), utils.smart_assign(self._qmodel_learning_rate, -alpha, force_cast=True), utils.smart_assign(self._qmodel_momentum, mu, force_cast=True)) with tf.control_dependencies([qmodel_assign_op]): return self._update_velocities( precon_grads_and_vars, mu, vec_coeff=-alpha) # NOTE: the very particular way this function is written is probably important # for it to work correctly with non-resource variables, which are very # unpredictable with regards to control flow. def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0): """Updates the velocities of the variables with the given vectors. Args: vecs_and_vars: List of (vector, variable) pairs. decay: How much to decay the old velocity by. This is often referred to as the 'momentum constant'. vec_coeff: Coefficient to apply to the vectors before adding them to the velocity. Returns: A list of (velocity, var) indicating the new velocity for each var. """ def _update_velocity(vec, var): velocity = self._zeros_slot(var, "velocity", self.get_name()) with tf.colocate_with(velocity): # NOTE(mattjj): read/modify/write race condition not suitable for async. # Compute the new velocity for this variable. new_velocity = decay * velocity + vec_coeff * vec # Save the updated velocity. return (tf.identity(utils.smart_assign(velocity, new_velocity)), var) # Go through variable and update its associated part of the velocity vector. return [_update_velocity(vec, var) for vec, var in vecs_and_vars] def _get_current_loss(self): if self._update_damping_immediately: return utils.all_average(self._loss_fn(self._train_batch)) return utils.all_average(self._loss_fn(self._prev_train_batch)) def _get_prev_loss(self): return tf.identity(self._prev_loss) def _update_damping(self): """Adapts damping parameter. 
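    Concretely, with rho = (current_loss - prev_loss) / qmodel_change, the
    damping is multiplied by omega (typically < 1) when rho exceeds the
    decrease threshold, divided by omega when rho falls below the increase
    threshold, and left unchanged otherwise, subject to a floor of
    min_damping + l2_reg.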
Check KFAC paper (Section 6.5) for the details. The damping parameter is updated according to the Levenberg-Marquardt rule every `self._damping_adaptation_interval` iterations. Essentially, the rule computes the reduction ratio "rho" and depending on the value either increases lambda, decreases it, or leaves it as is. The reduction ratio captures how closely the quadratic approximation to the loss function approximates the actual loss within a trust region. The damping update tries to make the damping as small as possible while maintaining the property that the quadratic model remains a good local approximation to the loss function. Returns: An Op to assign newly computed damping value to `self._damping`, and also updates the _rho member. """ prev_loss = self._get_prev_loss() current_loss = tf.cast(self._get_current_loss(), dtype=prev_loss.dtype) loss_change = current_loss - prev_loss rho = loss_change / self._qmodel_change should_decrease = tf.math.logical_or( tf.math.logical_and(loss_change < 0, self._qmodel_change > 0), rho > self._damping_decrease_rho_threshold) should_increase = rho < self._damping_increase_rho_threshold new_damping = tf.case( [(should_decrease, lambda: self.damping * self._omega), (should_increase, lambda: self.damping / self._omega)], default=lambda: self.damping) new_damping = tf.maximum(new_damping, self._min_damping + self._l2_reg) return tf.group(utils.smart_assign(self._damping, new_damping), utils.smart_assign(self._rho, rho, force_cast=True)) def _two_by_two_solve(m, vec): """Solve a 2x2 system by direct inversion. Args: m: A length 2 list of length 2 lists, is a 2x2 matrix of [[a, b], [c, d]]. vec: The length 2 list of length 1 lists, a vector of [e, f]. Returns: matmul(m^{-1}, vec). """ a = m[0][0] b = m[0][1] c = m[1][0] d = m[1][1] inv_m_det = 1.0 / (a * d - b * c) m_inverse = [ [d * inv_m_det, -b * inv_m_det], [-c * inv_m_det, a * inv_m_det] ] return tf.matmul(m_inverse, vec) def _eval_quadratic_no_c(m, vec): return 0.5*tf.matmul(tf.matmul(vec, m, transpose_a=True), vec) def _eval_quadratic(m, c, vec): return _eval_quadratic_no_c(m, vec) + tf.matmul(c, vec, transpose_a=True) ================================================ FILE: kfac/python/ops/placement.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Implements placement strategies for various ops and variables.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import itertools # Dependency imports import tensorflow.compat.v1 as tf from tensorflow.python.util import nest from kfac.python.ops import utils as utils def _make_thunk_on_device(func, device): def thunk(*args, **kwargs): with tf.device(device): return func(*args, **kwargs) return thunk class RoundRobinPlacementMixin(object): """Implements round robin placement strategy for ops and variables.""" def __init__(self, cov_devices=None, inv_devices=None, trans_devices=None, **kwargs): """Create a RoundRobinPlacementMixin object. Args: cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance computations will be placed on these devices in a round-robin fashion. Can be None or empty, which means that no devices are specified. inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion computations will be placed on these devices in a round-robin fashion. Can be None or empty, which means that no devices are specified. trans_devices: Iterable of device strings (e.g. '/gpu:0'). Transformation computations (e.g. multiplying different blocks by the inverse Fisher) will be placed on these devices in a round-robin fashion. Can be None or empty, which means that no devices are specified. **kwargs: Pass through arguments. """ super(RoundRobinPlacementMixin, self).__init__(**kwargs) self._cov_devices = cov_devices self._inv_devices = inv_devices self._trans_devices = trans_devices def _place_and_compute_transformation_thunks(self, thunks, params_list): """Computes transformation thunks with round-robin device placement. Device placement done in round-robin fashion according to the order of the `blocks` property, using the list `trans_devices` passed in to the constructor. Args: thunks: A list of thunks to run. Must be in one to one correspondence with the `blocks` property. params_list: A list of the corresponding parameters. Must be in one to one correspondence with the `blocks` property. Returns: A list (in the same order) of the returned results of the thunks, with round-robin device placement applied. """ del params_list if self._trans_devices: results = [] for thunk, device in zip(thunks, itertools.cycle(self._trans_devices)): with tf.device(device): results.append(thunk()) return results else: return tuple(thunk() for thunk in thunks) def create_ops_and_vars_thunks(self, scope=None): """Create thunks that make the ops and vars on demand with device placement. For each factor, all of that factor's cov variables and their associated update ops will be placed on a particular device. A new device is chosen for each factor by cycling through list of devices in the `self._cov_devices` attribute. If `self._cov_devices` is `None` then no explicit device placement occurs. An analogous strategy is followed for inverse update ops, with the list of devices being given by the `self._inv_devices` attribute. Inverse variables on the other hand are not placed on any specific device (they will just use the current the device placement context, whatever that happens to be). The idea is that the inverse variable belong where they will be accessed most often, which is the device that actually applies the preconditioner to the gradient. The user will be responsible for setting the device context for this. 
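    For example (illustrative only), with cov_devices=['/gpu:0', '/gpu:1']
    and four factors, the cov variables and update ops for factors 0 through
    3 would be placed on /gpu:0, /gpu:1, /gpu:0, and /gpu:1 respectively; the
    inverse update ops cycle through inv_devices in the same way.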
This function returns 4 lists of thunks: cov_variable_thunks, cov_update_thunks, inv_variable_thunks, and inv_update_thunks. The length of each list is the number of factors and the i-th element of each list corresponds to the i-th factor (given by the "factors" property). Note that the execution of these thunks must happen in a certain partial order. The i-th element of cov_variable_thunks must execute before the i-th element of cov_update_thunks (and also the i-th element of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks must execute before the i-th element of inv_update_thunks. TL;DR (oversimplified): Execute the thunks according to the order that they are returned. Args: scope: A string or None. If None it will be set to the name of this estimator (given by the name property). All variables will be created, and all thunks will execute, inside of a variable scope of the given name. (Default: None) Returns: cov_variable_thunks: A list of thunks that make the cov variables. cov_update_thunks: A list of thunks that make the cov update ops. inv_variable_thunks: A list of thunks that make the inv variables. inv_update_thunks: A list of thunks that make the inv update ops. """ (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, inv_update_thunks_raw) = self._create_ops_and_vars_thunks(scope=scope) if self._cov_devices: cov_variable_thunks = [] cov_update_thunks = [] for cov_variable_thunk, cov_update_thunk, device in zip( cov_variable_thunks_raw, cov_update_thunks_raw, itertools.cycle(self._cov_devices)): cov_variable_thunks.append(_make_thunk_on_device(cov_variable_thunk, device)) cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, device)) else: cov_variable_thunks = cov_variable_thunks_raw cov_update_thunks = cov_update_thunks_raw inv_variable_thunks = inv_variable_thunks_raw if self._inv_devices: inv_update_thunks = [] for inv_update_thunk, device in zip(inv_update_thunks_raw, itertools.cycle(self._inv_devices)): inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, device)) else: inv_update_thunks = inv_update_thunks_raw return (cov_variable_thunks, cov_update_thunks, inv_variable_thunks, inv_update_thunks) class ReplicaRoundRobinPlacementMixin(object): """Implements round robin placement strategy for certain ops on replicas. This placement strategy can be used in certain TPU training systems, where there are multiple "replicas" of the graph, such as in TPUEstimator or TF-Replicator. The execution of inverse and transformation ops, which by default occurs redundantly on all replicas, are instead distributed over replicas in a round-robin fashion. This is achieved by using tf.cond statements to check the replica id number. This placement strategy doesn't need to be used with TPU training, and may not work with all possible setups (such as TF Replicator). When it does work however, it may provide a substantial improvement in wall-clock time. """ def __init__(self, distribute_transformations=True, **kwargs): """Create a ReplicaRoundRobinPlacementMixin object. Args: distribute_transformations: Bool. If True we distribute certain vector transformations, such as multiplication by the preconditioner, across different replicas. Because this is a cheaper operation it may not always be worth the increase communication cost to do this. (Default: True) **kwargs: Pass through arguments. 
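    Raises:
      ValueError: If the current setup is not detected as a supported
        "replicated" setup (according to utils.is_replicated()).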
""" if not utils.is_replicated(): raise ValueError("This placement mode should only be used with certain " "kinds of 'replicated' setups, such as TPUEstimator " "or TF-Replicator.") self._distribute_transformations = distribute_transformations super(ReplicaRoundRobinPlacementMixin, self).__init__(**kwargs) def _place_and_compute_transformation_thunks(self, thunks, params_list): """Computes transformation thunks with round-robin replica placement. Replica placement done in round-robin fashion according to the order of the `blocks` property, cycling through the replicas in numerical order. Args: thunks: A list of thunks to run. Must be in one to one correspondence with the `blocks` property. params_list: A list of the corresponding parameters. Must be in one to one correspondence with the `blocks` property. Returns: A list (in the same order) of the returned results of the thunks, with round-robin replica placement applied. """ del params_list return utils.map_gather(thunks) def create_ops_and_vars_thunks(self, scope=None): """Create op/var-making thunks with replica placement for inverse ops. For each factor in the list of factors, the associated inverse ops will execute on a single replica which is chosen in round-robin fashion. Cov ops are run on all replicas, with the appropriate averaging done by using a few cross_replica_mean's that have been injected into the FisherFactor classes (and execute regardless if this mixin is being used). This function returns 4 lists of thunks: cov_variable_thunks, cov_update_thunks, inv_variable_thunks, and inv_update_thunks. The length of each list is the number of factors and the i-th element of each list corresponds to the i-th factor (given by the "factors" property). (Actually, for inv_update_thunks this class in particular returns only one thunk inside inv_update_thunks that updates all the factors.) Note that the execution of these thunks must happen in a certain partial order. The i-th element of cov_variable_thunks must execute before the i-th element of cov_update_thunks (and also the i-th element of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks must execute before the i-th element of inv_update_thunks. TL;DR (oversimplified): Execute the thunks according to the order that they are returned. Args: scope: A string or None. If None it will be set to the name of this estimator (given by the name property). All variables will be created, and all thunks will execute, inside of a variable scope of the given name. (Default: None) Returns: cov_variable_thunks: A list of thunks that make the cov variables. cov_update_thunks: A list of thunks that make the cov update ops. inv_variable_thunks: A list of thunks that make the inv variables. inv_update_thunks: A list of thunks that make the inv update ops. """ (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, inv_update_thunks_raw) = self._create_ops_and_vars_thunks(scope=scope) cov_variable_thunks = cov_variable_thunks_raw # all_averages of cov values are performed internally in the FisherFactor # classes, so we don't need to do anything for the cov updates here. cov_update_thunks = cov_update_thunks_raw inv_variable_thunks = inv_variable_thunks_raw # The thunks made here execute the supplied inverse update thunk and then # retrieve the values from the corresponding inverse variables. 
def make_thunk(inv_update_thunk, inv_vars): def thunk(): with tf.control_dependencies([inv_update_thunk()]): return nest.map_structure(tf.identity, inv_vars) return thunk # This single thunk calls map_gather to distribute the work, and then # saves the results back to the corresponding inverse variables. def inv_update_thunk(): assert len(inv_update_thunks_raw) == len(self.factors) # Create a list of factors and thunks that only include the factors # that have inverse variables. Note that not executing the inverse ops of # those that don't shouldn't matter. factors_and_thunks = tuple( (factor, thunk) for factor, thunk in zip(self.factors, inv_update_thunks_raw) if factor.get_inv_vars()) factors, _ = zip(*factors_and_thunks) thunks = tuple( make_thunk(inv_update_thunk, factor.get_inv_vars()) for factor, inv_update_thunk in factors_and_thunks) results = utils.map_gather(thunks) # These assigns save the values back to the variables. ops = (utils.smart_assign(var, val) for factor, result in zip(factors, results) # pylint: disable=g-complex-comprehension for val, var in zip(result, factor.get_inv_vars())) return tf.group(*ops) # Note that we have to return one big inv_update_thunk instead of one for # each factor. This is because utils.map_gather doesn't support returning # thunks (because TFReplicator's map_gather doesn't). inv_update_thunks = [inv_update_thunk] return (cov_variable_thunks, cov_update_thunks, inv_variable_thunks, inv_update_thunks) ================================================ FILE: kfac/python/ops/tensormatch/__init__.py ================================================ ================================================ FILE: kfac/python/ops/tensormatch/graph_matcher.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Pattern matcher for TensorFlow graphs in the Python object model. Writing Python to crawl through TensorFlow graphs can be a pain, and the resulting code is often hard to adapt, extend, and reuse. Instead of hand-writing that code, we should automatically generate it from a simple pattern-matching language. This package provides one such system. More precisely, this package defines a pattern language for matching and extracting nodes from TensorFlow graphs as represented in the Python object model. Patterns can be defined in Python code with a simple syntax and are compiled into compositions of continuation-passing matcher combinators. The mechanism for compiling the pattern language into combinators looks like an analyzing Scheme interpreter. The design comes from GJS's 6.945 at MIT. The pattern language compiler can be extended by registering new handlers at runtime, and new pattern compilers can be made by instantiating the PatternEvaluator class. 
The grammar for the pattern language implemented in this file is: pattern ::= element | choice | list | internal_node | negated_pattern | any patterns ::= pattern, patterns | () element ::= ('?', element_name, restrictions) element_name ::= PYTHON_STRING restrictions ::= PYTHON_FUNCTION, restrictions | () choice ::= ('?:choice', patterns) list ::= ('List', patterns) internal_node ::= (pattern, neighbor_constraints) neighbor_constraints ::= input_list | output_list | input_list, output_list input_list ::= ('In', patterns) output_list ::= ('Out', patterns) negated_pattern ::= ('?:not', pattern) any ::= ('?:any',) """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.util import tf_inspect from kfac.python.ops.tensormatch import tensorflow_graph_util as util def _any(itr): """Similar to Python's any, but returns the first value that matches.""" for val in itr: if val: return val return False def _all(itr): """Similar to Python's all, but returns the first value that doesn't match.""" any_iterations = False val = None for val in itr: any_iterations = True if not val: return val return val if any_iterations else True def is_seq(obj): return isinstance(obj, (tuple, list)) def is_nonempty_seq(obj): return is_seq(obj) and bool(obj) def is_empty_seq(obj): return is_seq(obj) and not bool(obj) ## define the syntax of the pattern language is_pattern = is_nonempty_seq def is_element_pattern(pat): return is_pattern(pat) and pat[0] == '?' def element_name(pat): return pat[1] def element_restrictions(pat): return pat[2:] def is_choice_pattern(pat): return is_pattern(pat) and pat[0] == '?:choice' def choice_patterns(pat): return pat[1:] def is_list_pattern(pat): return is_pattern(pat) and pat[0] == 'List' def list_patterns(pat): return pat[1:] def is_not_pattern(pat): return is_pattern(pat) and pat[0] == '?:not' def negated_pattern(pat): return pat[1] def is_any_pattern(pat): return is_pattern(pat) and pat[0] == '?:any' def is_any_noconsume_pattern(pat): return is_pattern(pat) and pat[0] == '?:any_noconsume' def is_internal_node_pattern(pat): def is_neighbor_constraints(lst): tags = tuple(item[0] for item in lst) return tags in {('In',), ('Out',), ('In', 'Out')} return (is_pattern(pat) and all(is_pattern(item) for item in pat) and is_neighbor_constraints(pat[1:])) def internal_node_pattern(pat): return pat[0] def internal_node_input_pattern(pat): for item in pat[1:]: if item[0] == 'In': return ('List',) + tuple(item[1:]) return ('?:any_noconsume',) def internal_node_output_pattern(pat): for item in pat[1:]: if item[0] == 'Out': return ('List',) + tuple(item[1:]) return ('?:any_noconsume',) def internal_patterns(pat): return [internal_node_pattern(pat), internal_node_input_pattern(pat), internal_node_output_pattern(pat)] ## constructors for pattern-matching combinators def match_eqv(pattern): def eqv_match(data, bindings, consumed, succeed): return data == pattern and succeed(bindings, consumed | {data}) return eqv_match def match_any(data, bindings, consumed, succeed): try: consumed = consumed | {data} # pylint: disable=g-no-augmented-assignment except TypeError: consumed = consumed | set(data) # pylint: disable=g-no-augmented-assignment return succeed(bindings, consumed) def match_any_noconsume(data, bindings, consumed, succeed): # pylint: disable=unused-argument # this combinator succeeds (but does not append to the consumed set) # regardless of the value of 'data', though the caller still passes 'data' # (since all 
combinators have the same signature) return succeed(bindings, consumed) def match_element(variable_name, restrictions): """Matches an element.""" def element_match(data, bindings, consumed, succeed): consumed = consumed | {data} # pylint: disable=g-no-augmented-assignment if _all(restriction(data) for restriction in restrictions): if not variable_name: return succeed(bindings, consumed) elif variable_name in bindings: return bindings[variable_name] == data and succeed(bindings, consumed) return succeed(dict(bindings, **{variable_name: data}), consumed) return False return element_match def match_choice(*match_combinators): def choice_match(data, bindings, consumed, succeed): return _any(matcher(data, bindings, consumed, succeed) for matcher in match_combinators) return choice_match def match_list(*match_combinators): """Matches a list.""" def list_match(data, bindings, consumed, succeed): return _list_match(data, match_combinators, bindings, consumed, succeed) def _list_match(data, matchers, bindings, consumed, succeed): """Apply matchers elementwise to a list, collecting bindings sequentially. Args: data: The list on which to apply the matcher list. matchers: The corresponding list of matchers to apply, element-by-element. bindings: The dictionary of bindings to be consistent with. consumed: The list of graph nodes consumed so far. succeed: The continuation function to call when there is a match. Returns: False if there is no match, or succeed(bindings) if there is one. """ def match_first_then_subsequent(combinator, datum): return combinator(datum, bindings, consumed, match_subsequent_elements) def match_subsequent_elements(bindings, consumed): return _list_match(data[1:], matchers[1:], bindings, consumed, succeed) if is_empty_seq(matchers) and is_empty_seq(data): return succeed(bindings, consumed) return (is_nonempty_seq(matchers) and is_nonempty_seq(data) and match_first_then_subsequent(matchers[0], data[0])) return list_match def match_not(match_combinator): def not_match(data, bindings, consumed, succeed): return (not match_combinator(data, bindings, set(), lambda bindings, _: True) and succeed(bindings, consumed)) return not_match def match_internal(*match_combinators): expanded_matcher = match_list(*match_combinators) def internal_node_match(data, bindings, consumed, succeed): try: expanded = [data, util.expand_inputs(data), util.expand_outputs(data)] except ValueError: return False return expanded_matcher(expanded, bindings, consumed, succeed) return internal_node_match ## parsing the pattern language into compositions of combinators class PatternEvaluator(object): """Pattern evaluator class.""" def __init__(self, default_operation=None): self.default_operation = default_operation self.handlers = [] def defhandler(self, predicate, handler): self.handlers.append((predicate, handler)) def __call__(self, pat): for predicate, handler in self.handlers: if predicate(pat): return handler(pat) if self.default_operation: return self.default_operation(pat) raise ValueError make_combinators = PatternEvaluator(match_eqv) make_combinators.defhandler( is_element_pattern, lambda pat: match_element(element_name(pat), element_restrictions(pat))) make_combinators.defhandler( is_list_pattern, lambda pat: match_list(*map(make_combinators, list_patterns(pat)))) make_combinators.defhandler( is_choice_pattern, lambda pat: match_choice(*map(make_combinators, choice_patterns(pat)))) make_combinators.defhandler( is_not_pattern, lambda pat: match_not(make_combinators(negated_pattern(pat)))) 
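# A rough doctest-style illustration of how these handlers compose
# (illustrative only; `matcher` is defined further below, and real patterns
# use restrictions such as util.is_op rather than a toy predicate):
#
#   >>> match = matcher(('?', 'x', lambda d: d > 0))
#   >>> match(5)
#   {'x': 5}
#   >>> match(-1)
#   False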
make_combinators.defhandler( is_any_pattern, lambda pat: match_any) make_combinators.defhandler( is_any_noconsume_pattern, lambda pat: match_any_noconsume) make_combinators.defhandler( is_internal_node_pattern, lambda pat: match_internal(*map(make_combinators, internal_patterns(pat)))) ## utility function so the patterns require fewer parentheses def expand_thunks(pat): """Expands thunks (zero-argument functions) in a pattern by calling them. Args: pat: The pattern to expand, possibly containing thunks. Returns: The expanded pattern. """ def is_thunk(x): if hasattr(x, '__call__'): spec = tf_inspect.getargspec(x) num_free_args = len(set(spec.args)) - len(set(spec.defaults or {})) return num_free_args == 0 return False while is_thunk(pat): pat = pat() if isinstance(pat, (tuple, list)): return type(pat)(map(expand_thunks, pat)) return pat ## main matcher interface functions def matcher(pattern): combinators = make_combinators(expand_thunks(pattern)) def match(node): return combinators(node, {}, set(), lambda bindings, _: bindings or True) return match def all_matcher(pattern): combinators = make_combinators(expand_thunks(pattern)) results = [] def all_matches(node): combinators(node, {}, set(), lambda bindings, _: results.append(bindings or True)) return results return all_matches def matcher_with_consumed(pattern): combinators = make_combinators(expand_thunks(pattern)) def match(node): return combinators(node, {}, set(), lambda bindings, consumed: (bindings, consumed)) return match ================================================ FILE: kfac/python/ops/tensormatch/graph_patterns.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Convenience functions for writing patterns in Python code..""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow.compat.v1 as tf from kfac.python.ops.tensormatch import tensorflow_graph_util as util ## patterns def Op(name=None): return ('?', name, util.is_op) def Tensor(name=None): return ('?', name, util.is_tensor) def Variable(name=None): return ('?', name, util.is_var) def Const(name=None): return ('?', name, util.is_const) def Placeholder(name=None): return ('?', name, util.is_placeholder) util.import_ops_no_clobber(globals(), dir(tf.raw_ops)) # NOTE(mattjj): renamed in TF 1.0, but not registered as an op in 1.0.1 Unstack = util.make_op_pattern('Unpack') # pylint: disable=invalid-name ## convenient compound patterns # The op definitions are pulled in via the op_def_registry, which is # why we disable the undefined variable check for e.g. Rsqrt, Mul, etc. # Otherwise we would have to refer to them by name rather than object. 
# pylint: disable=undefined-variable def BatchNorm(in_pattern=Tensor('in'), scale_name='scale', offset_name='offset', output_name='out'): """Pattern constructor for matching tf.nn.batch_normalization subgraphs.""" inv_pat = (Tensor('inv'), ('In', ('?:choice', Rsqrt, (Mul, ('In', (Tensor, ('In', Rsqrt)), Tensor(scale_name)))))) without_offset_pat = (Mul, ('In', Tensor, Tensor('inv'))) with_offset_pat = (Sub, ('In', Tensor(offset_name), (Tensor, ('In', (Mul, ('In', Tensor, Tensor('inv'))))))) return (Tensor(output_name), ('In', (AddV2, ('In', (Tensor, ('In', (Mul, ('In', in_pattern, inv_pat)))), (Tensor, ('In', ('?:choice', with_offset_pat, without_offset_pat))))))) def FusedBatchNormOutput(in_pattern=Tensor('in'), scale_name='scale', offset_name='offset', output_name='out'): """Pattern constructor for matching tf.nn.fused_batch_norm subgraphs.""" return (Tensor(output_name), ('In', (('?:choice', FusedBatchNorm, FusedBatchNormV2, FusedBatchNormV3), ('In', in_pattern, Tensor(scale_name), Tensor(offset_name), Tensor, Tensor)))) # TODO(mattjj): add more ops to this pattern Nonlinearity = ('?:choice', Relu, Tanh) # pylint: disable=invalid-name def ScaleAndShift(in_pattern=Tensor('in'), scale_name='scale', shift_name='shift', output_name='out'): """Pattern constructor for matching scale & shift operation subgraphs.""" scale_pat_r = (Mul, ('In', in_pattern, Variable(scale_name))) scale_pat_l = (Mul, ('In', Variable(scale_name), in_pattern)) scale_pat = ('?:choice', scale_pat_r, scale_pat_l) pat_r = (('?:choice', Add, AddV2), ('In', (Tensor, ('In', scale_pat)), Variable(shift_name))) pat_l = (('?:choice', Add, AddV2), ('In', Variable(shift_name), (Tensor, ('In', scale_pat)))) return (Tensor(output_name), ('In', ('?:choice', pat_r, pat_l, scale_pat))) def Affine(in_pattern=Tensor('in'), linear_op_name='linear_op', weights_name='weights', biases_name='biases', output_name='pre_activations'): """Pattern constructor for matching affine operation subgraphs.""" linear_pat = (('?:choice', Conv2D(linear_op_name), MatMul(linear_op_name), BatchMatMulV2(linear_op_name)), ('In', in_pattern, Variable(weights_name))) affine_pat_r = (('?:choice', Add, BiasAdd, AddV2), ('In', (Tensor, ('In', linear_pat)), Variable(biases_name))) affine_pat_l = (('?:choice', Add, BiasAdd, AddV2), ('In', Variable(biases_name), (Tensor, ('In', linear_pat)))) affine_pat = ('?:choice', affine_pat_r, affine_pat_l) return (Tensor(output_name), ('In', ('?:choice', affine_pat, linear_pat))) def Embed(in_pattern=Tensor('in'), linear_op_name='linear_op', weights_name='weights', axis_name='axis', output_name='pre_activations'): """Pattern constructor for matching embedding layer subgraphs.""" embed_v1 = (('?:choice', Gather(linear_op_name), ResourceGather(linear_op_name)), ('In', Variable(weights_name), in_pattern)) embed_v2 = (GatherV2(linear_op_name), ('In', Variable(weights_name), in_pattern, Tensor(axis_name))) embed = ('?:choice', embed_v1, embed_v2) return (Tensor(output_name), ('In', embed)) # Only used in tests: def Layer(in_pattern=Tensor('in'), **kwargs): """Pattern constructor for matching a basic layer.""" return (Tensor('activations'), ('In', (Nonlinearity, ('In', Affine( in_pattern, **kwargs))))) # Only used in tests: def LayerWithBatchNorm(in_pattern=Tensor('in')): """Pattern constructor for matching a layer with batch normalization.""" return (Tensor('final_activations'), ('In', (Nonlinearity, ('In', BatchNorm(Affine(in_pattern)))))) # pylint: enable=undefined-variable ================================================ FILE: 
kfac/python/ops/tensormatch/graph_search.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Functions for automatically registering network layers for K-FAC.""" import collections from absl import logging import enum import tensorflow.compat.v1 as tf from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import resource_variable_ops from kfac.python.ops import utils from kfac.python.ops.tensormatch import graph_matcher as gm from kfac.python.ops.tensormatch import graph_patterns as gp from kfac.python.ops.tensormatch import tensorflow_graph_util as graph_utils class RecordType(enum.Enum): fully_connected = 1 conv2d = 2 scale_and_shift = 3 batch_norm = 4 class AmbiguousRegistrationError(Exception): pass class MatchRecord(object): """An object for storing data about graph pattern matches.""" def __init__(self, record_type, params, tensor_set, data=None): """Construct a new `Record` object. Args: record_type: A `RecordType` representing the type of layer being recorded. params: A list of the variables used by this layer. tensor_set: A set of all tensors matched by the pattern. This is used for determining when one match is a subset of another. data: An optional dict for storing attributes specific to certain record types. """ self.record_type = record_type self.params = params self.tensor_set = tensor_set if data is None: data = dict() self.data = data def ensure_sequence(obj): """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" if isinstance(obj, (tuple, list)): return obj else: return (obj,) def record_affine_from_bindings(bindings, consumed_tensors, tensors_to_variables): """Construct a MatchRecord for the given Affine pattern bindings. Args: bindings: A dict representing a matched pattern. Strings representing components of the pattern are mapped to the matched Tensors. consumed_tensors: A set of all tensors consumed by the matched pattern. This should be a superset of the values of the bindings dict. tensors_to_variables: A dict mapping Tensors to the variables referencing them. Returns: A `MatchRecord` containing the information necessary to register the layer. Raises: ValueError: If the bindings contain biases but not weights. 
""" if 'biases' in bindings: biases = tensors_to_variables.get(bindings['biases']) else: biases = None weights = tensors_to_variables.get(bindings['weights'], None) inputs = bindings['in'] outputs = bindings['pre_activations'] linear_op = bindings['linear_op'] if biases is not None and weights is None: raise ValueError("Can't register linear layer part with only biases.") if weights is not None and biases is not None: params = (weights, biases) else: params = weights if params is not None: record_data = dict(inputs=inputs, outputs=outputs) is_sparse = (linear_op.type == 'Gather' or linear_op.type == 'GatherV2' or linear_op.type == 'ResourceGather') if (linear_op.type == 'MatMul' or linear_op.type == 'BatchMatMulV2' or is_sparse): record_type = RecordType.fully_connected if len(inputs.shape) >= 4 or (is_sparse and len(inputs.shape) >= 3): raise ValueError('K-FAC currently doesn''t support multi-use/temporal ' 'fully-connected layers with more than two batch/time ' 'dimensions. Two is the max, and they must be in the ' 'order [time, batch, ...]. Found this for params {} ' 'and op {}.'.format(params, repr(linear_op))) if ((linear_op.type == 'MatMul' and (linear_op.get_attr('transpose_a') or linear_op.get_attr('transpose_b'))) or (linear_op.type == 'BatchMatMulV2' and (linear_op.get_attr('adj_x') or linear_op.get_attr('adj_y')))): raise ValueError('K-FAC currently doesn''t support fully-connected ' 'layers with transposed inputs or weights as part of ' 'the actual op. Found this for params {} and ' 'op {}.'.format(params, repr(linear_op))) record_data['dense_inputs'] = not is_sparse elif linear_op.type == 'Conv2D': record_type = RecordType.conv2d strides = tuple(map(int, linear_op.get_attr('strides'))) padding = linear_op.get_attr('padding') data_format = linear_op.get_attr('data_format') # In Python 3 this might be class "bytes" so we convert to string. if not isinstance(padding, str): padding = padding.decode() if not isinstance(data_format, str): data_format = data_format.decode() record_data['strides'] = strides record_data['padding'] = padding record_data['data_format'] = data_format else: raise ValueError("Can't register operation: {}".format(repr(linear_op))) return MatchRecord( record_type=record_type, params=params, tensor_set=consumed_tensors, data=record_data) def record_scale_and_shift_from_bindings(bindings, consumed_tensors, tensors_to_variables): """Construct a MatchRecord for the given ScaleAndShift pattern bindings. Args: bindings: A dict representing a matched pattern. Strings representing components of the pattern are mapped to the matched Tensors. consumed_tensors: A set of all tensors consumed by the matched pattern. This should be a superset of the values of the bindings dict. tensors_to_variables: A dict mapping Tensors to the variables referencing them. Returns: A `MatchRecord` containing the information necessary to register the layer. """ if 'shift' in bindings: shift = tensors_to_variables.get(bindings['shift']) else: shift = None scale = tensors_to_variables.get(bindings['scale'], None) inputs = bindings['in'] outputs = bindings['out'] # I'm not sure if this can ever actually happen. 
if shift is not None and scale is None: raise ValueError("Can't register scale_and_shift with only shift.") if scale is not None and shift is not None: params = (scale, shift) else: params = scale if params is not None: record_data = dict(inputs=inputs, outputs=outputs) return MatchRecord( record_type=RecordType.scale_and_shift, params=params, tensor_set=consumed_tensors, data=record_data) def record_batch_norm_from_bindings(bindings, consumed_tensors, tensors_to_variables): """Construct a MatchRecord for the given BatchNorm pattern bindings. Args: bindings: A dict representing a matched pattern. Strings representing components of the pattern are mapped to the matched Tensors. consumed_tensors: A set of all tensors consumed by the matched pattern. This should be a superset of the values of the bindings dict. tensors_to_variables: A dict mapping Tensors to the variables referencing them. Returns: A `MatchRecord` containing the information necessary to register the layer. """ if 'offset' in bindings: offset = tensors_to_variables.get(bindings['offset']) else: offset = None if 'scale' in bindings: scale = tensors_to_variables.get(bindings['scale']) else: scale = None inputs = bindings['in'] outputs = bindings['out'] if scale is not None and offset is not None: params = (scale, offset) elif scale is not None: params = scale elif offset is not None: params = offset else: params = None if params is not None: record_data = dict(inputs=inputs, outputs=outputs) return MatchRecord( record_type=RecordType.batch_norm, params=params, tensor_set=consumed_tensors, data=record_data) def register_layers(layer_collection, varlist, batch_size=None): """Walk the graph and register all layers to layer_collection. Parameters used multiple times in the graph need to be handled differently depending on context: this could either mean the parameters represent an RNN layer, or that the graph has been replicated as multiple "towers" to allow data parallelism. We differentiate these cases by examining the loss functions registered by layer_collection: if losses have been registered multiple times with reuse=True, we separate the subgraphs corresponding to each tower and register layers independently for each with reuse=True. Args: layer_collection: A `LayerCollection` to use for registering layers. varlist: A list of the variables in the graph. batch_size: A `int` representing the batch size. Needs to specified if registering generic variables that don't match any layer patterns or if time/uses is folded. If the time/uses dimension is merged with batch then this is used to infer number of uses/time-steps. NOTE: In the replicated context this must be the per-replica batch size, and not the total batch size. Returns: A `dict` of the entries registered to layer_collection.fisher_blocks. Raises: ValueError: If not all losses were registered the same number of times. If any variables specified as part of linked groups were not matched with their group. If the same variable is used in multiple layers types (e.g. fully connected and 2d convolution), or if the same variable is used in multiple layers of a type that doesn't support shared parameters. AmbiguousRegistrationError: If any variables must be registered as generic and batch_size is not specified, or if even after filtering, there are matches with overlapping but unequal sets of variables (see filter_records). 
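  Example (a rough sketch; `logits` stands in for the model's output tensor
  and 128 for the per-replica batch size):

    layer_collection.register_softmax_cross_entropy_loss(logits)
    register_layers(layer_collection, tf.trainable_variables(),
                    batch_size=128)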
""" original_fisher_blocks = layer_collection.fisher_blocks.copy() user_registered_variables = set() for params in layer_collection.fisher_blocks.keys(): for variable in ensure_sequence(params): user_registered_variables.add(variable) user_registered_variables = frozenset(user_registered_variables) if not layer_collection.losses: raise ValueError('No registered losses found. Automatic registration ' 'requires all losses in the graph to be registered before ' 'it can begin.') else: inputs_by_loss = tuple(tuple(loss.inputs for loss in loss_list) for loss_list in layer_collection.towers_by_loss) num_towers = len(inputs_by_loss[0]) if not all( (len(input_tensors) == num_towers for input_tensors in inputs_by_loss)): raise ValueError( 'If losses are registered with reuse=True, each name must be ' 'registered the same number of times.') for tower_number, tower_input_tensors in enumerate(zip(*inputs_by_loss)): reuse = (tower_number > 0) with tf.variable_scope('tower_%d' % tower_number, reuse=reuse): subgraph = utils.SubGraph(tower_input_tensors) register_subgraph_layers( layer_collection, varlist, subgraph, user_registered_variables=user_registered_variables, reuse=reuse, batch_size=batch_size) fisher_blocks = layer_collection.fisher_blocks return { params: fisher_blocks[params] for params in set(fisher_blocks) - set(original_fisher_blocks) } def register_subgraph_layers(layer_collection, varlist, subgraph, user_registered_variables=frozenset(), reuse=False, batch_size=None): """Walk a subgraph and register all layers to layer_collection. Args: layer_collection: A `LayerCollection` to use for registering layers. varlist: A list of the variables in the graph. subgraph: The `SubGraph` to search. user_registered_variables: A set of all the variables the user has manually registered. No layers using any of these variables should be registered. reuse: (OPTIONAL) bool. If True, then `layer_collection` selects a previously registered block with the same key as the key derived from `params` of that block. If False, a new block is registered. batch_size: A `int` representing the batch size. Needs to specified if registering generic variables that don't match any layer patterns or if the time/uses dimension is folded into batch. If the time/uses dimension is merged with batch then this is used to infer number of uses/time-steps. Raises: ValueError: If any variables specified as part of linked groups were not matched with their group. If the same variable is used in multiple layers types (e.g. fully connected and 2d convolution), or if the same variable is used in multiple layers of a type that doesn't support shared parameters. AmbiguousRegistrationError: If any variables must be registered as generic and batch_size is not specified, or if even after filtering, there are matches with overlapping but unequal sets of variables (see filter_records). """ # List of patterns and binding functions to use when we match one of them match_register_list = [(gm.matcher_with_consumed(gp.Affine), record_affine_from_bindings), (gm.matcher_with_consumed(gp.ScaleAndShift), record_scale_and_shift_from_bindings), (gm.matcher_with_consumed(gp.BatchNorm), record_batch_norm_from_bindings), (gm.matcher_with_consumed(gp.FusedBatchNormOutput), record_batch_norm_from_bindings), (gm.matcher_with_consumed(gp.Embed), record_affine_from_bindings)] # Patterns return bindings to raw tensors, so we need to be able to map back # to variables from the tensors those variables reference. 
def var_to_tensors(var): if resource_variable_ops.is_resource_variable(var): if tf.control_flow_v2_enabled() and hasattr(layer_collection.graph, 'captures'): # TODO(b/143690035): Note that the "captures" property relies on an # API which might change. captures = layer_collection.graph.captures return [h for vh, h in captures if vh is var.handle] else: return [var.handle] if utils.is_reference_variable(var): return [tf_ops.internal_convert_to_tensor(var, as_ref=True)] raise ValueError('%s is not a recognized variable type.' % str(var)) tensors_to_variables = {tensor: var for var in varlist for tensor in var_to_tensors(var)} # Get all the ops from the graph. ops = layer_collection.graph.get_operations() # Filter out tf.identity ops since otherwise the matcher generates spurious # matches. ops = tuple(op for op in ops if not graph_utils.is_identity(op)) # Extract out the output tensors from the ops tensors = tuple(out for op in ops for out in op.outputs) # Filter the tensors to include only those in the subgraph. tensors = subgraph.filter_list(tensors) # Go through each tensor and try to match each pattern to it. record_list_dict = dict() for tensor in tensors: for match, recfunc in match_register_list: match_res = match(tensor) if match_res: bindings, consumed_tensors = match_res record = recfunc(bindings, consumed_tensors, tensors_to_variables) if record is not None: if record.params not in record_list_dict: record_list_dict[record.params] = [] record_list_dict[record.params].append(record) # Filter out records violating any rules. record_list_dict = filter_records(layer_collection, record_list_dict, user_registered_variables) # Register the layers by going through the lists of records for each param. register_records(layer_collection, record_list_dict, reuse, batch_size) # Determine which variables were registered either by the user or # in the current call to register_subgraph_layers. automatically_registered_variables = { var for params in record_list_dict for var in ensure_sequence(params) } registered_variables = ( automatically_registered_variables | user_registered_variables) # Register any remaining parameters generically. for variable in varlist: if variable not in registered_variables: for specified_grouping in layer_collection.linked_parameters: assert isinstance(specified_grouping, frozenset) if variable in specified_grouping and len(specified_grouping) > 1: raise ValueError( 'Variable {} in linked group {} was not matched.'.format( variable, specified_grouping)) generic_bad_string = ('generic registrations may be a symptom that the ' 'scanner is failing to auto-detect your model. ' 'Generic uses a last-resort approximation, and ' 'should never be used for common layer types that ' 'K-FAC properly supports, such as convs or ' 'fully-connected layers.') if batch_size is None: raise AmbiguousRegistrationError( ('Tried to register {} as generic without knowledge of batch_size. ' 'You can pass batch_size in to fix this error. But please note, ' + generic_bad_string).format(variable)) logging.warning(('Registering {} as generic because graph scanner ' 'couldn\'t match a pattern for it. This can sometimes ' 'be caused by the variable not being present in the ' 'graph terminating at the registered losses. You might ' 'need to pass an explicit list of parameters to tell ' 'the system what parameters are actually in your model. 
' 'Note that ' + generic_bad_string).format(variable)) layer_collection.register_generic(variable, batch_size, reuse=reuse) def filter_user_registered_records(record_list_dict, user_registered_variables): """Remove any matches that contain a variable registered by the user.""" record_list_dict = record_list_dict.copy() for params in list(record_list_dict.keys()): for variable in ensure_sequence(params): if variable in user_registered_variables: del record_list_dict[params] break return record_list_dict def filter_grouped_variable_records(layer_collection, record_list_dict): """Remove any matches violating user specified parameter groupings.""" record_list_dict = record_list_dict.copy() for params in list(record_list_dict.keys()): for specified_grouping in layer_collection.linked_parameters: param_set = set(ensure_sequence(params)) assert isinstance(specified_grouping, frozenset) if (param_set.intersection(specified_grouping) and param_set != specified_grouping): del record_list_dict[params] break return record_list_dict def filter_subgraph_records(record_list_dict): """Remove any matches that correspond to strict subgraphs of other matches.""" # Flatten the records dict to compare records with different parameters. flat_record_list = [ record for records in record_list_dict.values() for record in records ] # Compare all pairs of records that share any variables. We perform two # passes, first marking variables for deletion by adding them to a set and # then removing all marked variables, in order to avoid traversing # flat_record_list on every removal while still maintaining record order. records_by_variable = collections.defaultdict(list) for record in flat_record_list: for variable in ensure_sequence(record.params): records_by_variable[variable].append(record) records_to_remove = set() for record in flat_record_list: for variable in ensure_sequence(record.params): for other_record in records_by_variable[variable]: if record.tensor_set < other_record.tensor_set: records_to_remove.add(record) flat_record_list = [ record for record in flat_record_list if record not in records_to_remove ] # Unflatten the records list. record_list_dict = collections.defaultdict(list) for record in flat_record_list: record_list_dict[record.params].append(record) assert record is not None return dict(record_list_dict) def filter_records(layer_collection, record_list_dict, user_registered_variables): """Filter out recorded matches based on a set of rules. A match should be filtered out if any of the following are true: 1. It contains any variables already registered by the user. 2. It violates the user specified variable groupings. 3. It corresponds to a strict subgraph of another match not already filtered out by the above steps. Args: layer_collection: A `LayerCollection` to use for registering layers. record_list_dict: A dict mapping tuples of variables to lists of `MatchRecord`s representing all of the places those variables are used in the graph. user_registered_variables: A set of all the variables the user has manually registered. No layers using any of these variables should be registered. Returns: A copy of `record_list_dict` with the records violating rules filtered out. Raises: AmbiguousRegistrationError: If even after filtering, there are matches with overlapping but unequal sets of variables. In these cases, the user will need to either manually register layers that use these variables, or specify a preferred variable grouping. 
""" record_list_dict = filter_user_registered_records(record_list_dict, user_registered_variables) record_list_dict = filter_grouped_variable_records(layer_collection, record_list_dict) record_list_dict = filter_subgraph_records(record_list_dict) # Look for any violation in the consistency of the remaining matches. recorded_params = dict() ambiguous_registration_errors = [] for params in record_list_dict: for variable in ensure_sequence(params): if variable in recorded_params: ambiguous_registration_errors.append( 'Variable {} was recorded in multiple groups: {} and {}.'.format( variable, params, recorded_params[variable])) else: recorded_params[variable] = params if ambiguous_registration_errors: raise AmbiguousRegistrationError('\n'.join(ambiguous_registration_errors)) return record_list_dict def register_records(layer_collection, record_list_dict, reuse=False, batch_size=None): """Registers the given records to layer_collection. Args: layer_collection: A `LayerCollection` to use for registering layers. record_list_dict: A dict mapping tuples of variables to lists of `MatchRecord`s representing all of the places those variables are used in the graph. reuse: (OPTIONAL) bool. If True, then `layer_collection` selects a previously registered block with the same key as the key derived from `params` of that block. If False, a new block is registered. batch_size: A `int` representing the batch size. Needs to specified if registering generic variables that don't match any layer patterns or if time/uses is folded. If the time/uses dimension is merged with batch then this is used to infer number of uses/time-steps. Raises: ValueError: If record_list_dict contains multiple record types for a single set of variables, or if there are multiple records for a set of variables of a type that doesn't support shared parameters. AmbiguousRegistrationError: If a batch norm layer registration is required but batch_size is not passed. """ mixed_record_type_errors = [] # TODO(b/69627702): Layers must be registered in a deterministic order, else # FisherFactors may end up with different variable names. params_list = sorted(record_list_dict.keys(), key=str) for params in params_list: record_list = record_list_dict[params] # We don't support mixed types for the same params and probably never # will. if not all(record_list[0].record_type == record.record_type for record in record_list): mixed_record_type_errors.append( 'Detected variables {} with mixed record types: {}.'.format( params, record_list)) continue record_type = record_list[0].record_type if record_type is RecordType.fully_connected: dense_inputs = record_list[0].data['dense_inputs'] if (not dense_inputs and layer_collection._get_linked_approx(params) is None): # pylint: disable=protected-access # Nothing is lost by using a diagonal approx for the input factor here. # This is because the 2nd-moment matrix for 1-hot vectors will be # naturally diagonal. 
approx = 'kron_indep_in_diag' else: approx = None if len(record_list) > 1: logging.info( 'Registering as multi-use fully-connected: {}'.format(params)) inputs = tuple(record.data['inputs'] for record in record_list) outputs = tuple(record.data['outputs'] for record in record_list) layer_collection.register_fully_connected_multi( params, inputs, outputs, reuse=reuse, dense_inputs=dense_inputs, approx=approx) else: if dense_inputs: folded_dim_limit = 2 else: folded_dim_limit = 1 record = record_list[0] inputs = record.data['inputs'] outputs = record.data['outputs'] first_dim = inputs.shape.as_list()[0] num_dim = len(inputs.shape) is_batch_time_folded = not ( batch_size is None or first_dim is None or first_dim == batch_size or num_dim > folded_dim_limit) if is_batch_time_folded or num_dim > folded_dim_limit: logging.info( 'Registering as multi-use fully-connected: {}'.format(params)) logging.warning('Registering {} as multi-use fully-connected layer ' 'with folded batch and time/use dimension. If using ' 'the non-independent K-FAC RNNs approximations (' '"Option 1" or "Option 2") make sure that the ' 'dimensions are ordered [time/use, batch] before ' 'folding, and not the other way around. Otherwise ' 'you will get a silent failure of the method!' ''.format(params)) if is_batch_time_folded: if first_dim % batch_size != 0: raise ValueError('Passed batch_size did not divide first ' 'dimension of tensor with presumed folded ' 'batch and use/times dimension. Possible causes ' 'include passing the wrong batch size (e.g. ' 'passing overall instead of per-replica), or a ' 'non-standard layer (possibly with no batch ' 'dependency). Layer params are: ' '{}. Input and output tensors are: {} and {}' ''.format(params, inputs, outputs)) num_uses = first_dim // batch_size else: num_uses = record_list[0].data['inputs'].shape.as_list()[1] layer_collection.register_fully_connected_multi( params, inputs, outputs, num_uses=num_uses, reuse=reuse, dense_inputs=dense_inputs, approx=approx) else: logging.info('Registering as fully-connected: {}'.format(params)) layer_collection.register_fully_connected( params, inputs, outputs, reuse=reuse, dense_inputs=dense_inputs, approx=approx) elif record_type is RecordType.conv2d: if len(record_list) > 1: logging.info('Registering as multi-use conv2d: {}'.format(params)) inputs = tuple(record.data['inputs'] for record in record_list) outputs = tuple(record.data['outputs'] for record in record_list) strides = record_list[0].data['strides'] padding = record_list[0].data['padding'] data_format = record_list[0].data['data_format'] layer_collection.register_conv2d_multi( params, strides, padding, inputs, outputs, data_format=data_format, reuse=reuse) else: record = record_list[0] inputs = record.data['inputs'] outputs = record.data['outputs'] strides = record.data['strides'] padding = record.data['padding'] data_format = record.data['data_format'] first_dim = inputs.shape.as_list()[0] num_dim = len(inputs.shape) is_batch_time_folded = not ( batch_size is None or first_dim is None or first_dim == batch_size or num_dim > 4) if is_batch_time_folded or num_dim > 4: logging.info('Registering as multi-use conv2d: {}'.format(params)) if is_batch_time_folded: if first_dim % batch_size != 0: raise ValueError('Passed batch_size did not divide first ' 'dimension of tensor with presumed folded ' 'batch and use/times dimension. Possible causes ' 'include passing the wrong batch size (e.g. 
' 'passing overall instead of per-replica), or a ' 'non-standard layer (possibly with no batch ' 'dependency). Layer params are: ' '{}. Input and output tensors are: {} and {}' ''.format(params, inputs, outputs)) num_uses = first_dim // batch_size else: raise ValueError('Currently not supporting conv layers with ' 'separate time/uses dim.') layer_collection.register_conv2d_multi( params, strides, padding, inputs, outputs, data_format=data_format, num_uses=num_uses, reuse=reuse) else: logging.info('Registering as conv2d: {}'.format(params)) layer_collection.register_conv2d(params, strides, padding, inputs, outputs, data_format=data_format, reuse=reuse) elif record_type is RecordType.scale_and_shift: logging.info('Registering as scale (& shift): {}'.format(params)) if len(record_list) > 1: raise ValueError('Multi-use registrations currently not supported for ' 'scale & shift operations.') record = record_list[0] inputs = record.data['inputs'] outputs = record.data['outputs'] layer_collection.register_scale_and_shift(params, inputs, outputs, reuse=reuse) elif record_type is RecordType.batch_norm: # For now we register this as generic instead of scale_and_shift because # the fused version of batch norm won't give us the quantities we need # for the latter. Could consider splitting this into fused and non-fused # cases. logging.info('Registering as generic (batch norm): {}'.format(params)) if batch_size is None: raise AmbiguousRegistrationError( 'Tried to register a batch norm layer (as generic) without ' 'knowledge of batch_size. You can pass batch_size in to fix this ' 'error.') # This is a slight hack. Ideally register_generic would work with lists # of params like it used to before we switched to the "unflattened" cov # representation so we wouldn't need to detect the approximation type. will_use_diag = ( layer_collection._get_linked_approx(params) == 'diagonal' # pylint: disable=protected-access or (layer_collection.default_generic_approximation == 'diagonal' and layer_collection._get_linked_approx(params) is None) # pylint: disable=protected-access ) if will_use_diag: for param in ensure_sequence(params): layer_collection.register_generic(param, batch_size, reuse=reuse) else: layer_collection.register_generic(params, batch_size, reuse=reuse) else: assert False, 'Invalid record type {}'.format(record_type) if mixed_record_type_errors: raise ValueError('\n'.join(mixed_record_type_errors)) ================================================ FILE: kfac/python/ops/tensormatch/tensorflow_graph_util.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== """Abstraction layer for working with the TensorFlow graph model.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import six import tensorflow.compat.v1 as tf from tensorflow.python.ops import resource_variable_ops from kfac.python.ops import utils # pylint: disable=g-import-not-at-top try: from tensorflow.python.types import core except ModuleNotFoundError: from tensorflow.python.framework import ops as tf_ops # pylint: enable=g-import-not-at-top def is_op(node): return isinstance(node, tf.Operation) def is_tensor(node): try: # TODO(b/154650521): Use tf.Tensor instead of core.Tensor. return isinstance(node, core.Tensor) except NameError: return tf_ops.is_dense_tensor_like(node) def is_var(node): if not is_tensor(node): return False if node.op.type.startswith('Variable'): return True if ((resource_variable_ops.is_resource_variable(node) or utils.is_reference_variable(node))): return True # TODO(b/143690035): Note that the Placeholder type handles the Control Flow # V2 case, but this could stop working in the future if the implementation of # Control Flow V2 changes. if node.dtype == tf.resource and (node.op.type == 'VarHandleOp' or node.op.type == 'Placeholder'): return True return False def is_const(node): return is_tensor(node) and node.op.type == 'Const' def is_placeholder(node): return is_tensor(node) and node.op.type == 'Placeholder' def is_leaf(node): return is_var(node) or is_const(node) or is_placeholder(node) def is_identity(node): if not is_op(node): return False # For ResourceVariables, a 'ReadVariableOp' has a single 'Enter' input, which # in turn has a Tensor with dtype == resource as input. return (node.type in {'Identity', 'ReadVariableOp', 'Enter', 'IdentityN'} or 'convert_gradient_to_tensor' in node.type) def op_type_is(typename): def is_op_with_typename(node): return is_op(node) and node.type == typename return is_op_with_typename def reduce_identity_ops(node): while is_tensor(node) and is_identity(node.op): # IdentityN is sometimes used when custom gradients are involved. Its # two inputs should be the same in that case. Otherwise there should only # be one input. assert (len(node.op.inputs) == 1 or (node.op.type == 'IdentityN' and node.op.inputs[0] == node.op.inputs[1])) node = node.op.inputs[0] return node def expand_inputs(node): """Return a list of input nodes for a given TF graph node (or node list).""" if is_op(node): return [reduce_identity_ops(tensor) for tensor in node.inputs[:]] elif is_tensor(node) and not is_leaf(node): return [reduce_identity_ops(node).op] elif isinstance(node, list) and all(is_tensor(elt) for elt in node): ops = {reduce_identity_ops(tensor).op for tensor in node} if len(ops) == 1: return [ops.pop()] raise ValueError return None def expand_outputs(node): """Return a list of output nodes for a given TF graph node.""" if is_op(node): return node.outputs[:] elif isinstance(node, tf.Variable): return node.value().consumers() elif is_tensor(node): return node.consumers() return None def make_op_pattern(typename): """Makes a pattern that matches a given Op type.""" def op_fun(name=None): return ('?', name, op_type_is(typename)) op_fun_name = typename.encode('ascii', 'ignore') # In Python 3, str.encode() produces a bytes object. Convert this to an ASCII # str. 
if six.PY3: op_fun_name = op_fun_name.decode('ascii') op_fun.__name__ = op_fun_name return op_fun def import_ops_no_clobber(dct, op_names): for name in op_names: if name not in dct: dct[name] = make_op_pattern(name) ================================================ FILE: kfac/python/ops/utils.py ================================================ # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utility functions.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections # Dependency imports import numpy as np import tensorflow.compat.v1 as tf from tensorflow.python.tpu import tpu_function from tensorflow.python.tpu.ops import tpu_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import nest # Method used for inverting matrices. POSDEF_INV_METHOD = "cholesky" POSDEF_EIG_METHOD = "self_adjoint" _TF_REPLICATOR = None def smart_assign(variable, value, assign_fn=tf.assign, force_cast=False, force_sync=True): """Calls assign_fn on variable and value in a cross-replica context. When this function is called in a per-replica context, it will enter a cross- replica context before calling assign_fn(variable, value). During training with a tf.distribute.Strategy, optimizer.minimize is always called in a per- replica context (e.g. via experimental_run for TPUStrategy). Since with this function we assign a synchronized Tensor to a MirroredVariable with assign_fn, we use a merge_call to enter a cross-replica context, then use distribution.extended.update to assign value to variable with assign_fn. When this function is called in a cross-replica context or outside of a tf.distribute.Strategy scope, smart_assign will use assign_fn as is. Operations that happen inside of a tf.distribute.Strategy scope are typically in a cross replica context, unless, for example, they happen in an experimental_run call or a call_for_each_replica call. In a cross-replica context, tf.distribute.get_replica_context() returns None. Args: variable: TF Variable. A MirroredVariable when in a distribution strategy. value: TF Tensor. This function will throw an error if value is a PerReplica type, which means it is an unsynchronized Tensor. You must reduce it using all_sum or all_average before using this function. assign_fn: assign_fn(variable, value) -> tf.Operation. The function used to update variable with value, typically tf.assign, tf.assign_add, or tf.assign_sub. force_cast: Boolean. If True we cast the `value` to the dtype of `variable` when they don't match. (Default: False) force_sync: Boolean. If True and using MirroredStrategy in a replica context, take the mean of value over all replicas to force the value to be syncronized before performing the assignment. Returns: tf.Tensor that contains the result of assign_fn(variable, value) called in a cross-replica context. 
""" if force_cast and variable.dtype != value.dtype: value = tf.cast(value, dtype=variable.dtype) if not (tf.distribute.has_strategy() and tf.distribute.get_replica_context()): return assign_fn(variable, value) def merge_fn(distribution, variable, value): strategy = tf.distribute.get_strategy() if isinstance(strategy, tf.distribute.MirroredStrategy) and force_sync: value = strategy.reduce(tf.distribute.ReduceOp.MEAN, value) return distribution.extended.update(variable, assign_fn, args=(value,)) return tf.distribute.get_replica_context().merge_call( merge_fn, args=(variable, value)) def smart_cond(predicate, true_fn, false_fn, name=None): """Creates ops for conditionally executing one of two functions. If MirroredStrategy is not used or outside of a MirroredStrategy replica context, this is identical to tf.cond. tf.cond does not support using functions which involve synchronization calls inside a MirroredStrategy replica context. Instead, work around this by safely evaluating the conditional across replicas and then evaluate either true_fn or false_fn back in a replica context. Note: this is only required if true_fn and/or false_fn involve a synchronization across replicas (e.g. via a reduction to evaluate the cross-replica mean). Limitations: with MirroredStrategy, true_fn and false_fn are executed via control_dependencies are a constant tensor is returned instead of the actual return values of true_fn and false_fn. This is due to the requirement that functions executed using DistributionStrategy.call_for_each_replica return a tensor rather than an operation. Args: predicate: boolean operation which determines whether to execute true_fn or false_fn. true_fn: function to execute if predicate is true. false_fn: function to execute if predicate is false. name: name to assign to the tf.cond operation. Returns: If not using MirroredStrategy or outside of a MirroredStrategy replica context, the result from true_fn or false_fn, and otherwise a constant tensor. """ if (tf.distribute.has_strategy() and tf.distribute.get_replica_context()): strategy = tf.distribute.get_strategy() else: strategy = None if not isinstance(strategy, tf.distribute.MirroredStrategy): return tf.cond(predicate, true_fn, false_fn, name) else: # Conditionals with functions which execute synchronization calls are not # well supported with Distribution Strategy. Instead follow the scheme # suggested in https://github.com/tensorflow/tensorflow/issues/27716: # 1. Execute the conditional in a cross-replica context. # 2. The conditional functions then return to a replica-context before # executing the original conditional functions. def true_fn_per_replica(): # call_for_each_replica requires a tensor to be returned. This is not true # for all functions (which, e.g., might return an op or tf.group) so # instead execute the ops as control dependency and return a constant # tensor. 
with tf.control_dependencies([true_fn()]): return tf.constant(0.0) def true_fn_cross_replica(): strategy = tf.distribute.get_strategy() return strategy.extended.call_for_each_replica(true_fn_per_replica) def false_fn_per_replica(): with tf.control_dependencies([false_fn()]): return tf.constant(0.0) def false_fn_cross_replica(): strategy = tf.distribute.get_strategy() return strategy.extended.call_for_each_replica(false_fn_per_replica) def cond(distribution): del distribution return tf.cond(predicate, true_fn_cross_replica, false_fn_cross_replica, name) return tf.distribute.get_replica_context().merge_call(cond) def set_global_constants(posdef_inv_method=None, tf_replicator=None): """Sets various global constants used by the classes in this module.""" global POSDEF_INV_METHOD global _TF_REPLICATOR if posdef_inv_method is not None: POSDEF_INV_METHOD = posdef_inv_method if tf_replicator is not None: _TF_REPLICATOR = tf_replicator class SequenceDict(object): """A dict convenience wrapper that allows getting/setting with sequences.""" def __init__(self, iterable=None): self._dict = dict(iterable or []) def __getitem__(self, key_or_keys): if isinstance(key_or_keys, (tuple, list)): return list(map(self.__getitem__, key_or_keys)) else: return self._dict[key_or_keys] def __setitem__(self, key_or_keys, val_or_vals): if isinstance(key_or_keys, (tuple, list)): for key, value in zip(key_or_keys, val_or_vals): self[key] = value else: self._dict[key_or_keys] = val_or_vals def items(self): return list(self._dict.items()) def tensors_to_column(tensors): """Converts a tensor or list of tensors to a column vector. Args: tensors: A tensor or list of tensors. Returns: The tensors reshaped into vectors and stacked on top of each other. """ if isinstance(tensors, (tuple, list)): return tf.concat( tuple(tf.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0) else: return tf.reshape(tensors, [-1, 1]) def column_to_tensors(tensors_template, colvec): """Converts a column vector back to the shape of the given template. Args: tensors_template: A tensor or list of tensors. colvec: A 2d column vector with the same shape as the value of tensors_to_column(tensors_template). Returns: X, where X is tensor or list of tensors with the properties: 1) tensors_to_column(X) = colvec 2) X (or its elements) have the same shape as tensors_template (or its elements) """ if isinstance(tensors_template, (tuple, list)): offset = 0 tensors = [] for tensor_template in tensors_template: sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32) tensor = tf.reshape(colvec[offset:(offset + sz)], tensor_template.shape) tensors.append(tensor) offset += sz tensors = tuple(tensors) else: tensors = tf.reshape(colvec, tensors_template.shape) return tensors def kronecker_product(mat1, mat2): """Computes the Kronecker product two matrices.""" m1, n1 = mat1.get_shape().as_list() mat1_rsh = tf.reshape(mat1, [m1, 1, n1, 1]) m2, n2 = mat2.get_shape().as_list() mat2_rsh = tf.reshape(mat2, [1, m2, 1, n2]) return tf.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) def layer_params_to_mat2d(vector): """Converts a vector shaped like layer parameters to a 2D matrix. In particular, we reshape the weights/filter component of the vector to be 2D, flattening all leading (input) dimensions. If there is a bias component, we concatenate it to the reshaped weights/filter component. Args: vector: A Tensor or pair of Tensors shaped like layer parameters. Returns: A 2D Tensor with the same coefficients and the same output dimension. 
""" if isinstance(vector, (tuple, list)): w_part, b_part = vector w_part_reshaped = tf.reshape(w_part, [-1, w_part.shape.as_list()[-1]]) return tf.concat((w_part_reshaped, tf.reshape(b_part, [1, -1])), axis=0) elif isinstance(vector, tf.IndexedSlices): return vector else: # Tensor or Tensor-like. return tf.reshape(vector, [-1, vector.shape.as_list()[-1]]) def mat2d_to_layer_params(vector_template, mat2d): """Converts a canonical 2D matrix representation back to a vector. Args: vector_template: A Tensor or pair of Tensors shaped like layer parameters. mat2d: A 2D Tensor with the same shape as the value of layer_params_to_mat2d(vector_template). Returns: A Tensor or pair of Tensors with the same coefficients as mat2d and the same shape as vector_template. """ if isinstance(vector_template, (tuple, list)): w_part, b_part = mat2d[:-1], mat2d[-1] return tf.reshape(w_part, vector_template[0].shape), b_part elif isinstance(vector_template, tf.IndexedSlices): if not isinstance(mat2d, tf.IndexedSlices): raise TypeError( "If vector_template is an IndexedSlices, so should mat2d.") return mat2d else: return tf.reshape(mat2d, vector_template.shape) def posdef_inv(tensor, damping): """Computes the inverse of tensor + damping * identity.""" identity = tf.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) damping = tf.cast(damping, dtype=tensor.dtype) return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping) def posdef_inv_matrix_inverse(tensor, identity, damping): """Computes inverse(tensor + damping * identity) directly.""" return tf.matrix_inverse(tensor + damping * identity) def posdef_inv_cholesky(tensor, identity, damping): """Computes inverse(tensor + damping * identity) with Cholesky.""" chol = tf.linalg.cholesky(tensor + damping * identity) return tf.linalg.cholesky_solve(chol, identity) def posdef_inv_eig(tensor, identity, damping): """Computes inverse(tensor + damping * identity) with eigendecomposition.""" eigenvalues, eigenvectors = tf.self_adjoint_eig(tensor + damping * identity) return tf.matmul(eigenvectors / eigenvalues, eigenvectors, transpose_b=True) posdef_inv_functions = { "matrix_inverse": posdef_inv_matrix_inverse, "cholesky": posdef_inv_cholesky, "eig": posdef_inv_eig, } def posdef_eig(mat): """Computes the eigendecomposition of a positive semidefinite matrix.""" return posdef_eig_functions[POSDEF_EIG_METHOD](mat) def posdef_eig_svd(mat): """Computes the singular values and left singular vectors of a matrix.""" evals, evecs, _ = tf.svd(mat) return evals, evecs def posdef_eig_self_adjoint(mat): """Computes eigendecomposition using self_adjoint_eig.""" evals, evecs = tf.self_adjoint_eig(mat) evals = tf.abs(evals) # Should be equivalent to svd approach. return evals, evecs posdef_eig_functions = { "self_adjoint": posdef_eig_self_adjoint, "svd": posdef_eig_svd, } def cholesky(tensor, damping): """Computes the inverse of tensor + damping * identity.""" identity = tf.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) damping = tf.cast(damping, dtype=tensor.dtype) return tf.linalg.cholesky(tensor + damping * identity) class SubGraph(object): """Defines a subgraph given by all the dependencies of a given set of outputs. """ def __init__(self, outputs): # Set of all ancestor Tensors, Ops to 'outputs'. 
self._members = set() self._iter_add(outputs) self._graph = outputs[0].graph def _iter_add(self, root): """Iteratively adds all of nodes' ancestors using depth first search.""" stack = [root] while stack: nodes = stack.pop() for node in nodes: if node in self._members: continue self._members.add(node) if isinstance(node, tf.Tensor): stack.append((node.op,)) elif isinstance(node, tf.Operation): stack.append(node.inputs) def is_member(self, node): """Check if 'node' is in this subgraph.""" return node in self._members def variable_uses(self, var): """Computes number of times a variable is used. Args: var: Variable or ResourceVariable instance. Returns: Number of times a variable is used within this subgraph. Raises: ValueError: If 'var' is not a variable type. """ def _add_tensor_consumers_to_set(tensor, consumers_set): """Finds consumers of a tensor and add them to the current consumers set. """ for consumer in set(tensor.consumers()): # These are the type of ops which relay a tensor to other ops without # doing anything to the tensor value, so recursively find the actual # consumers. if consumer.type in [ "Identity", "ReadVariableOp", "Enter", "ResourceGather"]: for output in consumer.outputs: _add_tensor_consumers_to_set(output, consumers_set) else: consumers_set.add(consumer) consumers = set() if resource_variable_ops.is_resource_variable(var): if tf.control_flow_v2_enabled() and hasattr(self._graph, "captures"): # TODO(b/143690035): Note that the "captures" property relies on an API # which might change. captures = self._graph.captures for handle in [h for vh, h in captures if vh is var.handle]: _add_tensor_consumers_to_set(handle, consumers) else: _add_tensor_consumers_to_set(var.handle, consumers) elif is_reference_variable(var): _add_tensor_consumers_to_set(var.value(), consumers) else: raise ValueError("%s does not appear to be a variable." % str(var)) return len(self._members.intersection(consumers)) def filter_list(self, node_list): """Filters 'node_list' to nodes in this subgraph.""" filtered_list = [] for node in node_list: if self.is_member(node): filtered_list.append(node) return filtered_list def preferred_int_dtype(): # tf.int32 doesn't work properly on GPUs, and tf.int64 isn't recommended on # TPUs. Hence this function. if is_tpu_replicated(): return tf.int32 else: return tf.int64 def generate_random_signs(shape, dtype=tf.float32): """Generate a random tensor with {-1, +1} entries.""" ints = tf.random_uniform(shape, maxval=2, dtype=preferred_int_dtype()) return 2 * tf.cast(ints, dtype=dtype) - 1 # MirroredVariables do not have a hashable op property, which means they cannot # be used with stop_gradients. This was fixed in the TF-Nightly release, but is # not in any stable release, so we use the below hack so our fwd_gradients # function works in the TF 1.14 stable release. # TODO(b/139376871): Remove this workaround once the bugfix is in a stable release. DistributedVarOp = collections.namedtuple( "DistributedVarOp", ["name", "graph", "traceback", "type"]) class MirroredVariableWrapper(object): def __init__(self, var): self.__var = var def __getattr__(self, name): if name == 'op': return DistributedVarOp( self.__var.op.name, self.__var.op.graph, # In the updated TF codebase, convert_stack returns tuple instead of # list, which makes op.traceback hashable. 
tuple(self.__var.op.traceback), self.__var.op.type) else: return getattr(self.__var, name) def _as_list(x): return x if isinstance(x, (list, tuple)) else [x] def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None, colocate_gradients_with_ops=True): """Compute forward-mode gradients.""" # See b/37888268. # This version of forward-mode autodiff is based on code by Tim Cooijmans # and handles list arguments and certain special cases such as when the # ys doesn't depend on one or more of the xs, and when tf.IndexedSlices are # generated by the first tf.gradients call. ys = _as_list(ys) xs = _as_list(xs) us = [tf.zeros_like(y) + float("nan") for y in ys] if tf.distribute.has_strategy(): stop_gradients = [MirroredVariableWrapper(v) for v in stop_gradients] dydxs = tf.gradients(ys, xs, grad_ys=us, stop_gradients=stop_gradients, colocate_gradients_with_ops=colocate_gradients_with_ops) # Deal with strange types that tf.gradients returns but can't # deal with. dydxs = [ tf.convert_to_tensor(dydx) if isinstance(dydx, tf.IndexedSlices) else dydx for dydx in dydxs ] dydxs = [ tf.zeros_like(x) if dydx is None else dydx for x, dydx in zip(xs, dydxs) ] dysdx = tf.gradients(dydxs, us, grad_ys=grad_xs, colocate_gradients_with_ops=colocate_gradients_with_ops) return dysdx def get_tf_replicator(): return _TF_REPLICATOR def is_tpu_replicated(): is_tpu_strategy = (tf.distribute.has_strategy() and tf.distribute.get_replica_context() and isinstance(tf.distribute.get_strategy(), tf.distribute.experimental.TPUStrategy)) num_shards = tpu_function.get_tpu_context().number_of_shards return is_tpu_strategy or num_shards is not None def is_replicated(): """Check if we are operating in a supported replicated context.""" if tf.distribute.has_strategy() and tf.distribute.get_replica_context(): return tf.distribute.get_strategy().num_replicas_in_sync > 1 return get_tf_replicator() is not None or is_tpu_replicated() def get_num_replicas(): """Returns the number of replicas. If not operating in a supported replicated context this function will return 1. """ tf_replicator = get_tf_replicator() if tf_replicator: return tf_replicator.num_replicas_in_sync elif tf.distribute.has_strategy(): return tf.distribute.get_strategy().num_replicas_in_sync else: # I'm assuming replicas and shards are always equal until someone tells me # different. num_replicas = tpu_function.get_tpu_context().number_of_shards if num_replicas: return num_replicas else: return 1 def get_replica_id(): """Returns an id number for the current replica, counting from 0. If not operating in a supported replicated context this function will return 0. """ tf_replicator = get_tf_replicator() if tf_replicator: return tf_replicator.current_replica_id elif tf.distribute.has_strategy() and tf.distribute.get_replica_context(): return tf.distribute.get_replica_context().replica_id_in_sync_group # This code below this point is based on # TensorTracer._add_replica_id_to_graph(). num_replicas = get_num_replicas() if num_replicas <= 1: return 0 with tf.control_dependencies(None): # Uses None as dependency to run outside of TPU graph rewrites. return tpu_ops.tpu_replicated_input(list(range(num_replicas)), name="replica_id") def all_sum(structure, name=None): """Sums the contents of a nested structure across all replicas. If not operating in a supported replicated context this function acts like the identity. Args: structure: A nested structure of Tensors. name: None or string. Optional name of Op. 
(Default: None) Returns: A nested structure with the corresponding Tensors being the cross-replica summed versions of those in `structure`. """ num_replicas = get_num_replicas() if num_replicas <= 1: return structure tf_replicator = get_tf_replicator() if tf_replicator: return tf_replicator.all_sum(structure) elif tf.distribute.has_strategy() and tf.distribute.get_replica_context(): return tf.distribute.get_replica_context().all_reduce( tf.distribute.ReduceOp.SUM, structure) elif is_tpu_replicated(): def tpu_all_sum(tensor): return tpu_ops.cross_replica_sum(tensor, name=name) return nest.map_structure(tpu_all_sum, structure) return structure def all_average(structure, name=None): """Averages the contents of a nested structure across all replicas. If not operating in a supported replicated context this function acts like the identity. Args: structure: A nested structure of Tensors. name: None or string. Optional name of Op. (Default: None) Returns: A nested structure with the corresponding Tensors being the cross-replica averaged versions of those in `structure`. """ num_replicas = get_num_replicas() if num_replicas <= 1: return structure if (tf.distribute.has_strategy() and tf.distribute.get_replica_context() and not get_tf_replicator()): return tf.distribute.get_replica_context().all_reduce( tf.distribute.ReduceOp.MEAN, structure) return nest.map_structure(lambda x: x / num_replicas, all_sum(structure, name=name)) def map_gather(thunks, name=None): """Distributes the execution of thunks over replicas, then gathers results. This method can be used to distribute several expensive computations across the replicas, rather than duplicating the computation in all of them. Args: thunks: A list of thunks that each returns a nested structure of Tensors. These should all have statically known shapes. name: None or string. Optional name of Op. (Default: None) Returns: A list of nested structures of Tensors representing the return values of the list of thunks. """ num_replicas = get_num_replicas() if num_replicas <= 1: return tuple(thunk() for thunk in thunks) tf_replicator = get_tf_replicator() if tf_replicator: return tf_replicator.map_gather(thunks, lambda thunk: thunk()) elif is_tpu_replicated(): replica_id = get_replica_id() def zeros_like(tensor): return tf.zeros(dtype=tensor.dtype, shape=tensor.shape) results = [] for idx, thunk in enumerate(thunks): # TensorFlow's optimization should eliminate the actual computations # done to compute example_structure, using only the (static) shape # information. def make_zeros_thunk(example_structure): def zeros_thunk(): return nest.map_structure(zeros_like, example_structure) return zeros_thunk # This trick of using cross_replica_sum with tensors of zeros is # obviously wasteful in terms of commmunication. A better solution would # involve only communicating the tensors from replicas where `include_me` # was True. include_me = tf.equal(replica_id, idx % num_replicas) results.append( all_sum(tf.cond(include_me, thunk, make_zeros_thunk(thunk()), strict=True), name=name)) return results return tuple(thunk() for thunk in thunks) def ensure_sequence(obj): """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" if isinstance(obj, (tuple, list)): return obj else: return (obj,) def batch_execute(global_step, thunks, batch_size, name=None): """Executes a subset of ops per global step. Given a list of thunks, each of which produces a single stateful op, ensures that exactly 'batch_size' ops are run per global step. 
Ops are scheduled in a round-robin fashion. For example, with 3 ops global_step | op0 | op1 | op2 ------------+-----+-----+----- 0 | x | x | ------------+-----+-----+----- 1 | x | | x ------------+-----+-----+----- 2 | | x | x ------------+-----+-----+----- 3 | x | x | ------------+-----+-----+----- 4 | x | | x Does not guarantee order of op execution within a single global step. Args: global_step: Tensor indicating time. Determines which ops run. thunks: List of thunks. Each thunk encapsulates one op. Return values are ignored. batch_size: int. Number of ops to execute per global_step. name: string or None. Name scope for newly added ops. Returns: List of ops. Exactly 'batch_size' ops are guaranteed to have an effect every global step. """ def true_fn(thunk): """Ensures thunk is executed and returns an Op (not a Tensor).""" def result(): with tf.control_dependencies([thunk()]): return tf.no_op() return result def false_fn(_): """Executes a no-op.""" def result(): return tf.no_op() return result with tf.name_scope(name, "batch_execute"): true_fns = [true_fn(thunk) for thunk in thunks] false_fns = [false_fn(thunk) for thunk in thunks] num_thunks = len(thunks) conditions = [ tf.less( tf.mod(batch_size - 1 + global_step * batch_size - j, num_thunks), batch_size) for j in range(num_thunks) ] result = [ tf.cond(condition, true_fn, false_fn) for (condition, true_fn, false_fn) in zip(conditions, true_fns, false_fns) ] return result def extract_convolution_patches(inputs, filter_shape, padding, strides=None, dilation_rate=None, name=None, data_format=None): """Extracts inputs to each output coordinate in tf.nn.convolution. This is a generalization of tf.extract_image_patches() to tf.nn.convolution(), where the number of spatial dimensions may be something other than 2. Assumes, - First dimension of inputs is batch_size - Convolution filter is applied to all input channels. Args: inputs: Tensor of shape [batch_size, ..spatial_image_shape.., ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution(). filter_shape: List of ints. Shape of filter passed to tf.nn.convolution(). padding: string. Padding method. One of "VALID", "SAME". strides: None or list of ints. Strides along spatial dimensions. dilation_rate: None or list of ints. Dilation along spatial dimensions. name: None or str. Name of Op. data_format: None or str. Format of data. Returns: Tensor of shape [batch_size, ..spatial_image_shape.., ..spatial_filter_shape.., in_channels] Raises: ValueError: If data_format does not put channel last. ValueError: If inputs and filter disagree on in_channels. """ if not is_data_format_channel_last(data_format): raise ValueError("Channel must be last dimension.") with tf.name_scope(name, "extract_convolution_patches", [inputs, filter_shape, padding, strides, dilation_rate]): batch_size = inputs.shape.as_list()[0] in_channels = inputs.shape.as_list()[-1] # filter_shape = spatial_filter_shape + [in_channels, out_channels] spatial_filter_shape = filter_shape[:-2] if in_channels != filter_shape[-2]: raise ValueError("inputs and filter_shape must agree on in_channels.") # Map each input feature to a location in the output. 
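    # Convolving with this identity "filter bank" copies each
    # (spatial offset, input channel) element of a receptive field into its
    # own output channel, so the convolution's output is exactly the patches.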
out_channels = np.prod(spatial_filter_shape) * in_channels filters = tf.eye(out_channels, dtype=inputs.dtype) filters = tf.reshape( filters, list(spatial_filter_shape) + [in_channels, out_channels]) if strides is not None and len(strides) == len(inputs.shape): strides = strides[1:-1] # remove batch and channel dimension if dilation_rate is not None and len(dilation_rate) == len(inputs.shape): dilation_rate = dilation_rate[1:-1] # remove batch and channel dimension result = tf.nn.convolution( inputs, filters, padding=padding, strides=strides, dilation_rate=dilation_rate) spatial_output_shape = result.shape.as_list()[1:-1] result = tf.reshape(result, [batch_size or -1] + spatial_output_shape + list(spatial_filter_shape) + [in_channels]) return result def extract_pointwise_conv2d_patches(inputs, filter_shape, name=None, data_format=None): """Extract patches for a 1x1 conv2d. Args: inputs: 4-D Tensor of shape [batch_size, height, width, in_channels]. filter_shape: List of 4 ints. Shape of filter to apply with conv2d() name: None or str. Name for Op. data_format: None or str. Format for data. See 'data_format' in tf.nn.conv2d() for details. Returns: Tensor of shape [batch_size, ..spatial_input_shape.., ..spatial_filter_shape.., in_channels] Raises: ValueError: if inputs is not 4-D. ValueError: if filter_shape is not [1, 1, ?, ?] ValueError: if data_format is not channels-last. """ if inputs.shape.ndims != 4: raise ValueError("inputs must have 4 dims.") if len(filter_shape) != 4: raise ValueError("filter_shape must have 4 dims.") if filter_shape[0] != 1 or filter_shape[1] != 1: raise ValueError("filter_shape must have shape 1 along spatial dimensions.") if not is_data_format_channel_last(data_format): raise ValueError("data_format must be channels last.") with tf.name_scope(name, "extract_pointwise_conv2d_patches", [inputs, filter_shape]): ksizes = [1, 1, 1, 1] # Spatial shape is 1x1. strides = [1, 1, 1, 1] # Operate on all pixels. rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1. padding = "VALID" # Doesn't matter. result = tf.extract_image_patches(inputs, ksizes, strides, rates, padding) batch_size, input_height, input_width, in_channels = inputs.shape.as_list() filter_height, filter_width, in_channels, _ = filter_shape return tf.reshape(result, [ batch_size, input_height, input_width, filter_height, filter_width, in_channels ]) def is_data_format_channel_last(data_format): """True if data_format puts channel last.""" if data_format is None: return True return data_format.endswith("C") def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name """Computes matmul(A, B) where A is sparse, B is dense. Args: A: tf.IndexedSlices with dense shape [m, n]. B: tf.Tensor with shape [n, k]. name: str. Name of op. transpose_a: Bool. If true we transpose A before multiplying it by B. (Default: False) transpose_b: Bool. If true we transpose B before multiplying it by A. (Default: False) Returns: tf.IndexedSlices resulting from matmul(A, B). Raises: ValueError: If A doesn't represent a matrix. ValueError: If B is not rank-2. """ with tf.name_scope(name, "matmul_sparse_dense", [A, B]): if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2: raise ValueError("A must represent a matrix. Found: %s." 
% A) if B.shape.ndims != 2: raise ValueError("B must be a matrix.") new_values = tf.matmul( A.values, B, transpose_a=transpose_a, transpose_b=transpose_b) return tf.IndexedSlices( new_values, A.indices, dense_shape=tf.stack([A.dense_shape[0], new_values.shape[1]])) def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name """Computes matmul(A, B) where A is a diagonal matrix, B is sparse. Args: A_diag: diagonal entries of matrix A of shape [m, m]. B: tf.IndexedSlices. Represents matrix of shape [m, n]. name: str. Name of op. Returns: tf.IndexedSlices resulting from matmul(A, B). Raises: ValueError: If A_diag is not rank-1. ValueError: If B doesn't represent a matrix. """ with tf.name_scope(name, "matmul_diag_sparse", [A_diag, B]): A_diag = tf.convert_to_tensor(A_diag) if A_diag.shape.ndims != 1: raise ValueError("A_diag must be a rank-1 Tensor.") if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2: raise ValueError("B must represent a matrix. Found: %s." % B) a = tf.gather(A_diag, B.indices) a = tf.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1)) return tf.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape) class AccumulatorVariable(object): """A simple abstraction to accumulate data that we want to average. Basically this variable accumulates data across multiple inputs, and then returns the average of these contributes on command. This accumulation can be reset by the user at any point. """ def __init__(self, name, shape, dtype): """Constructs a new `AccumulatorVariable`. Args: name: `string`. Scope for the variables. shape: shape of the variable. dtype: dtype of the variable. """ with tf.variable_scope(name, reuse=tf.AUTO_REUSE): self._acc_var = tf.get_variable( "acc_var", shape=shape, dtype=dtype, initializer=tf.zeros_initializer(), trainable=False, use_resource=True) # We may be able to make give this a VariableAggregation of # ONLY_FIRST_REPLICA, because we only add 1 or reset it to 0 (it does not # rely on per-replica values). If we do, we can update this in a per- # replica context instead of the cross-replica context. This may improve # efficiency when using a VariableSynchronization of ON_READ. self._counter = tf.get_variable( "counter", shape=(), dtype=tf.float32, initializer=tf.zeros_initializer(), trainable=False, use_resource=True) def accumulate(self, value): """Adds `value` to the accumulated data.""" inc_counter_op = smart_assign(self._counter, 1.0, assign_fn=tf.assign_add) acc_op = smart_assign(self._acc_var, value, assign_fn=tf.assign_add) return tf.group(inc_counter_op, acc_op) @property def value(self): """Returns the average of the accumulated values since the last reset.""" return self._acc_var / tf.cast(self._counter, self._acc_var.dtype) def read_value_and_reset(self): """Same as `value` property but resets after the data is read.""" value = self.value with tf.control_dependencies([value]): with tf.control_dependencies([self.reset()]): return tf.identity(value) def reset(self): """Resets the accumulated data to zero.""" var_reset_op = smart_assign( self._acc_var, tf.zeros(self._acc_var.shape, dtype=self._acc_var.dtype)) counter_reset_op = smart_assign(self._counter, tf.constant(0.0, dtype=tf.float32)) return tf.group(var_reset_op, counter_reset_op) class PartitionedTensor(object): """A Tensor partitioned across its 0-th dimension.""" def __init__(self, tensors): """Initializes PartitionedTensor. Args: tensors: List of Tensors. All Tensors must agree on shape (excepting batch dimension) and dtype. 
class PartitionedTensor(object):
  """A Tensor partitioned across its 0-th dimension."""

  def __init__(self, tensors):
    """Initializes PartitionedTensor.

    Args:
      tensors: List of Tensors. All Tensors must agree on shape (excepting
        batch dimension) and dtype.

    Raises:
      ValueError: If 'tensors' has length zero.
      ValueError: if contents of 'tensors' don't agree on shape or dtype.
    """
    if not tensors:
      raise ValueError("tensors must be a list of 1+ Tensors.")

    dtype = tensors[0].dtype
    if not all(tensor.dtype == dtype for tensor in tensors):
      raise ValueError(
          "all tensors must have the same dtype. The tensors are {}".format(
              tensors))

    shape = tensors[0].shape[1:]
    if not all(tensor.shape[1:] == shape for tensor in tensors):
      raise ValueError("All tensors must have shape = %s (excluding batch "
                       "dimension)." % shape)

    one_hot_depth = getattr(tensors[0], "one_hot_depth", None)
    if not all(
        getattr(tensor, "one_hot_depth", None) == one_hot_depth
        for tensor in tensors):
      raise ValueError(
          "All tensors must have one_hot_depth {}".format(one_hot_depth))

    self.tensors = tensors

  @property
  def shape(self):
    feature_shape = self.tensors[0].shape[1:]
    batch_size = sum([tensor.shape[0] for tensor in self.tensors],
                     tf.Dimension(0))
    return tf.TensorShape([batch_size]).concatenate(feature_shape)

  def get_shape(self):
    return self.shape

  @property
  def dtype(self):
    return self.tensors[0].dtype

  @property
  def one_hot_depth(self):
    return getattr(self.tensors[0], "one_hot_depth", None)

  def __str__(self):
    return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
        self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))

  def __hash__(self):
    return hash(tuple(self.tensors))

  def __eq__(self, other):
    if not isinstance(other, PartitionedTensor):
      return False
    return self.tensors == other.tensors

  def __ne__(self, other):
    return not self == other  # pylint: disable=g-comparison-negation

  def __getitem__(self, key):
    return self.as_tensor()[key]

  def as_tensor(self, dtype=None, name=None, as_ref=False):
    with tf.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
      assert not as_ref
      assert dtype in [None, self.dtype]
      return tf.concat(self.tensors, axis=0)

  @property
  def device(self):
    # PartitionedTensors in general do not live on a single device. If the
    # device cannot be determined unambiguously this property will return None.
    device = self.tensors[0].device
    if all(tensor.device == device for tensor in self.tensors):
      return device
    return None


tf.register_tensor_conversion_function(
    PartitionedTensor,
    lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))
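
# --- Illustrative usage sketch (added for exposition; not part of the original
# module). It shows how two per-shard minibatches can be wrapped in a
# PartitionedTensor and later materialized as a single Tensor via as_tensor(),
# which the conversion function registered above also uses. Shapes are made up.
def _partitioned_tensor_usage_sketch():
  shard_a = tf.zeros([4, 8])
  shard_b = tf.ones([6, 8])
  partitioned = PartitionedTensor([shard_a, shard_b])
  # `shape` reports the combined batch dimension: [10, 8].
  combined_shape = partitioned.shape
  # Materializing the value concatenates the shards along axis 0.
  doubled = 2.0 * partitioned.as_tensor()
  return combined_shape, doubled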
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.


def _check_match_lists_of_pairs(list1, list2):
  for (_, var1), (_, var2) in zip(list1, list2):
    if var1 is not var2:
      raise ValueError("The variables referenced by the two arguments "
                       "must match.")


def sprod(scalar, list_):
  # Product of scalar with list of items.
  return tuple(scalar*item for item in list_)


def sprod_p(scalar, list_):
  # Product of scalar with list of (item, var) pairs.
  return tuple((scalar*item, var) for (item, var) in list_)


def sum_(list1, list2):
  # Element-wise sum of lists of tensors.
  return tuple(item1 + item2 for item1, item2 in zip(list1, list2))


def sum_p(list1, list2):
  # Element-wise sum of lists of (tensor, var) pairs.
  _check_match_lists_of_pairs(list1, list2)
  return tuple((item1 + item2, var1)
               for (item1, var1), (item2, var2) in zip(list1, list2))


def ip(list1, list2):
  # Inner product of lists of tensors.
  return tf.add_n(tuple(tf.reduce_sum(tensor1 * tensor2)
                        for tensor1, tensor2 in zip(list1, list2)))


def ip_p(list1, list2):
  # Inner product of lists of (tensor, var) pairs.
  _check_match_lists_of_pairs(list1, list2)
  return ip(tuple(tensor for (tensor, _) in list1),
            tuple(tensor for (tensor, _) in list2))


def assert_variables_match_pairs_list(a_and_vars, b_and_vars,
                                      error_message=None):
  """Assert the variables in two lists of (tensor, var) pairs are the same.

  Args:
    a_and_vars: a list of (tensor, variable) pairs.
    b_and_vars: a list of (tensor, variable) pairs.
    error_message: an optional string prepended to the error message.

  Raises:
    ValueError: if any variables in the input pair lists are not the same.
  """
  _, a_variables = zip(*a_and_vars)
  _, b_variables = zip(*b_and_vars)
  variable_mismatch_indices = []
  for vi, (a_var, b_var) in enumerate(zip(a_variables, b_variables)):
    if a_var is not b_var:
      variable_mismatch_indices.append(vi)
  if variable_mismatch_indices:
    mismatch_indices_str = ", ".join(map(str, variable_mismatch_indices))
    a_variables_str = ", ".join(map(str, a_variables))
    b_variables_str = ", ".join(map(str, b_variables))
    error_str = ("Mismatch on variable lists at indices {}.\n\nFirst list: {}"
                 "\n\nSecond list: {}\n").format(
                     mismatch_indices_str, a_variables_str, b_variables_str)
    if error_message:
      error_str = "{} {}".format(error_message, error_str)
    raise ValueError(error_str)


def multiline_print(lists):
  """Prints multiple lines of output using tf.print."""
  combined_list = []
  combined_list += lists[0]
  # We prepend newline characters to strings at the start of lines to avoid
  # the ugly space indentations that tf.print's behavior of separating
  # everything with a space would otherwise cause.
  for item in lists[1:]:
    if isinstance(item[0], str):
      combined_list += (("\n" + item[0],) + item[1:])
    else:
      combined_list += (("\n",) + item)
  return tf.print(*combined_list)


def get_shape(tensor):
  """Returns list of dimensions using ints only for statically known ones."""
  if tensor.shape.dims is None:
    raise ValueError("Unknown rank for tensor {}.".format(tensor))
  static_shape = tensor.shape.as_list()
  dynamic_shape = tf.shape(tensor)
  return tuple(elt if elt is not None else dynamic_shape[idx]
               for idx, elt in enumerate(static_shape))


def cls_name(obj):
  return obj.__class__.__name__


def is_reference_variable(x):
  return ((isinstance(x, tf.Variable) and
           not resource_variable_ops.is_resource_variable(x)) or
          hasattr(x, "_should_act_as_ref_variable"))
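
# --- Illustrative usage sketch (added for exposition; not part of the original
# module). The *_p helpers above treat a list of (tensor, variable) pairs --
# the format returned by Optimizer.compute_gradients -- as a single flat
# vector. The variables and gradients below are invented for the example.
def _pair_list_helpers_usage_sketch():
  w = tf.get_variable("example_w", shape=[2, 2], dtype=tf.float32)
  b = tf.get_variable("example_b", shape=[2], dtype=tf.float32)
  grads_and_vars = [(tf.ones([2, 2]), w), (tf.ones([2]), b)]

  scaled = sprod_p(0.5, grads_and_vars)      # scale every gradient by 0.5
  combined = sum_p(grads_and_vars, scaled)   # elementwise sum; vars must match
  sq_norm = ip_p(grads_and_vars, grads_and_vars)  # scalar: 4 * 1 + 2 * 1 = 6
  return combined, sq_norm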
class MovingAverageVariable(object):
  """A variable updated using weighted moving averages.

  Note that to implement a traditional decaying exponential average one should
  use a decay value smaller than 1.0 (e.g. 0.9), and set weight = 1.0 - decay.
  Doing this and setting normalize_value to True will implement "zero-debiased"
  decayed averages.
  """

  def __init__(self, name, shape, dtype, initializer=tf.zeros_initializer(),
               normalize_value=True):
    """Constructs a new `MovingAverageVariable`.

    Args:
      name: `string`. Scope for the variables.
      shape: shape of the variable.
      dtype: dtype of the variable.
      initializer: initializer for the variable (see tf.get_variable). Should
        be tf.zeros_initializer() unless you know what you are doing.
        (Default: tf.zeros_initializer())
      normalize_value: bool. If True we normalize the value property by the
        total weight (which will be subject to decay). (Default: True)
    """
    self._normalize_value = normalize_value

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
      self._var = tf.get_variable(
          "var",
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          trainable=False,
          use_resource=True)

      self._total_weight = tf.get_variable(
          "total_weight",
          shape=(),
          dtype=dtype,
          initializer=tf.zeros_initializer(),
          trainable=False,
          use_resource=True)

  @property
  def dtype(self):
    return self._var.dtype.base_dtype

  @property
  def value(self):
    if self._normalize_value:
      return self._var / self._total_weight
    else:
      return tf.identity(self._var)

  def add_to_average(self, value, decay=1.0, weight=1.0):
    """Add a value into the moving average.

    Args:
      value: a Tensor matching the shape and dtype that was passed to the
        constructor.
      decay: float or 0D Tensor. The current value is multiplied by this before
        the value is added, as is the total accumulated weight. (Default: 1.0)
      weight: float or 0D Tensor. The value being added is multiplied by this.
        Also this is added to the total accumulated weight. (Default: 1.0)
    """
    decay = tf.cast(decay, dtype=self.dtype)
    weight = tf.cast(weight, dtype=self.dtype)

    update_var = smart_assign(self._var, decay * self._var + weight * value)
    update_total_weight = smart_assign(self._total_weight,
                                       decay * self._total_weight + weight)
    return tf.group(update_var, update_total_weight)

  def reset(self):
    return tf.group(
        smart_assign(self._var, tf.zeros_like(self._var)),
        smart_assign(self._total_weight, tf.zeros_like(self._total_weight)))


def num_conv_locations(input_shape, filter_shape, strides, padding):
  """Returns the number of spatial locations a conv kernel is applied to.

  Args:
    input_shape: List of ints representing shape of inputs to
      tf.nn.convolution().
    filter_shape: List of ints representing shape of filter to
      tf.nn.convolution().
    strides: List of ints representing strides along spatial dimensions as
      passed in to tf.nn.convolution().
    padding: string representing the padding method, either 'VALID' or 'SAME'.

  Returns:
    A scalar |T| denoting the number of spatial locations for the Conv layer.

  Raises:
    ValueError: If input_shape, filter_shape don't represent a 1-D or 2-D
      convolution.
  """
  if len(input_shape) != 4 and len(input_shape) != 3:
    raise ValueError("input_shape must be length 4, corresponding to a Conv2D,"
                     " or length 3, corresponding to a Conv1D.")
  if len(input_shape) != len(filter_shape):
    raise ValueError("Inconsistent number of dimensions between input and "
                     "filter for convolution")

  if strides is None:
    if len(input_shape) == 4:
      strides = [1, 1, 1, 1]
    else:
      strides = [1, 1, 1]

  # Use negative integer division to implement 'rounding up'.
  # Formula for convolution shape taken from:
  # http://machinelearninguru.com/computer_vision/basics/convolution/convolution_layer.html
  if len(input_shape) == 3:
    if padding is not None and padding.lower() == "valid":
      out_width = -(-(input_shape[1] - filter_shape[0] + 1) // strides[1])
    else:
      out_width = -(-input_shape[1] // strides[1])

    return out_width
  else:
    if padding is not None and padding.lower() == "valid":
      out_height = -(-(input_shape[1] - filter_shape[0] + 1) // strides[1])
      out_width = -(-(input_shape[2] - filter_shape[1] + 1) // strides[2])
    else:
      out_height = -(-input_shape[1] // strides[1])
      out_width = -(-input_shape[2] // strides[2])

    return out_height * out_width
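
# --- Worked example (added for exposition; not part of the original module).
# For a 32x32 input, a 3x3 kernel and unit strides, num_conv_locations gives
# |T| = 32 * 32 = 1024 spatial locations under 'SAME' padding and
# (32 - 3 + 1)^2 = 900 under 'VALID' padding. The batch size and channel
# counts below are arbitrary since only the spatial dimensions matter.
def _num_conv_locations_sketch():
  same_locations = num_conv_locations(
      [128, 32, 32, 3], [3, 3, 3, 16], [1, 1, 1, 1], "SAME")   # 1024
  valid_locations = num_conv_locations(
      [128, 32, 32, 3], [3, 3, 3, 16], [1, 1, 1, 1], "VALID")  # 900
  return same_locations, valid_locations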
================================================
FILE: setup.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Install kfac."""
from setuptools import find_packages
from setuptools import setup

setup(
    name='kfac',
    version='0.2.4',
    description='K-FAC for TensorFlow',
    author='Google Inc.',
    author_email='no-reply@google.com',
    url='http://github.com/tensorflow/kfac',
    license='Apache 2.0',
    packages=find_packages(exclude=[
        'kfac.examples.*',
        'kfac.python.kernel_tests.*',
    ]),
    install_requires=[
        'numpy',
        'six',
        'tensorflow-probability==0.8',
        'h5py<3',
    ],
    extras_require={
        # It's possible that you might need to put tensorflow<2.0 here:
        'tensorflow': ['tensorflow>=1.14'],
        # It's possible that you might need to put tensorflow-gpu<2.0 here:
        'tensorflow_gpu': ['tensorflow-gpu>=1.14'],
        # dm-sonnet<2.0 will force tensorflow<2.0 in the tests:
        'tests': ['pytest', 'dm-sonnet<2.0', 'numpy<1.20'],
    },
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
    keywords='tensorflow machine learning',
)
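
# --- Installation sketch (added for exposition; not part of setup.py).
# Given the extras declared above, a local checkout would typically be
# installed with one of the following, picking the TensorFlow build you need
# (the exact pins may need adjusting, as the comments in extras_require note):
#
#   pip install .                  # core package only
#   pip install .[tensorflow]      # plus CPU TensorFlow >= 1.14
#   pip install .[tensorflow_gpu]  # plus GPU TensorFlow >= 1.14
#   pip install .[tests]           # plus pytest and dm-sonnet for the tests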