Repository: tensorflow/kfac
Branch: master
Commit: ddad6375bbde
Files: 69
Total size: 989.3 KB
Directory structure:
gitextract_xp2vjhqo/
├── .travis.yml
├── AUTHORS
├── LICENSE
├── README.md
├── docs/
│ ├── applications.md
│ ├── contact.md
│ ├── examples/
│ │ ├── auto_damp.md
│ │ ├── convolutional.md
│ │ ├── distributed_training.md
│ │ └── parameters.md
│ ├── index.md
│ ├── papers.md
│ └── sitemap.md
├── kfac/
│ ├── __init__.py
│ ├── examples/
│ │ ├── __init__.py
│ │ ├── autoencoder_mnist.py
│ │ ├── autoencoder_mnist_tpu_estimator.py
│ │ ├── autoencoder_mnist_tpu_strategy.py
│ │ ├── classifier_mnist.py
│ │ ├── classifier_mnist_tpu_estimator.py
│ │ ├── convnet.py
│ │ ├── keras/
│ │ │ ├── KFAC_vs_Adam_Experiment.md
│ │ │ ├── KFAC_vs_Adam_on_CIFAR10.ipynb
│ │ │ └── KFAC_vs_Adam_on_CIFAR10_TPU.ipynb
│ │ ├── mnist.py
│ │ └── rnn_mnist.py
│ └── python/
│ ├── __init__.py
│ ├── keras/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── callbacks.py
│ │ ├── optimizers.py
│ │ ├── saving_utils.py
│ │ └── utils.py
│ ├── kernel_tests/
│ │ ├── data_reader_test.py
│ │ ├── estimator_test.py
│ │ ├── graph_search_test.py
│ │ ├── keras_callbacks_test.py
│ │ ├── keras_optimizers_test.py
│ │ ├── keras_saving_utils_test.py
│ │ ├── keras_utils_test.py
│ │ ├── layer_collection_test.py
│ │ ├── loss_functions_test.py
│ │ ├── op_queue_test.py
│ │ ├── optimizer_test.py
│ │ ├── periodic_inv_cov_update_kfac_opt_test.py
│ │ └── utils_test.py
│ └── ops/
│ ├── __init__.py
│ ├── curvature_matrix_vector_products.py
│ ├── estimator.py
│ ├── fisher_blocks.py
│ ├── fisher_factors.py
│ ├── kfac_utils/
│ │ ├── __init__.py
│ │ ├── async_inv_cov_update_kfac_opt.py
│ │ ├── data_reader.py
│ │ ├── data_reader_alt.py
│ │ └── periodic_inv_cov_update_kfac_opt.py
│ ├── layer_collection.py
│ ├── linear_operator.py
│ ├── loss_functions.py
│ ├── op_queue.py
│ ├── optimizer.py
│ ├── placement.py
│ ├── tensormatch/
│ │ ├── __init__.py
│ │ ├── graph_matcher.py
│ │ ├── graph_patterns.py
│ │ ├── graph_search.py
│ │ └── tensorflow_graph_util.py
│ └── utils.py
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .travis.yml
================================================
language: python
python:
  - "3.6"
env:
  matrix:
    - TF_VERSION="1.15"
install:
  - pip install -q "tensorflow==$TF_VERSION"
  - pip install -q .[tests]
  # Make sure we have the latest version of numpy - avoid problems we were
  # seeing with Python 3
  - pip install -q -U numpy
script:
  # Check import
  - python -c "import kfac; print(kfac.LayerCollection.__name__)"
  # Run tests
  - pytest
git:
  depth: 3
================================================
FILE: AUTHORS
================================================
# This is the official list of TensorFlow authors for copyright purposes.
# This file is distinct from the CONTRIBUTORS files.
# See the latter for an explanation.
# Names should be added to this file as:
# Name or Organization <email address>
# The email address is not required for organizations.
Google Inc.
================================================
FILE: LICENSE
================================================
Copyright 2019 The TensorFlow Authors. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2019, The TensorFlow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# K-FAC: Kronecker-Factored Approximate Curvature
[Build Status](https://travis-ci.org/tensorflow/kfac)
**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an
approximate second-order optimization method, in TensorFlow.
[kfac-paper]: https://arxiv.org/abs/1503.05671
## Installation
`kfac` is compatible with Python 2 and 3 and can be installed directly via
`pip`:
```shell
# Assumes tensorflow or tensorflow-gpu installed
$ pip install kfac
# Installs with tensorflow-gpu requirement
$ pip install 'kfac[tensorflow_gpu]'
# Installs with tensorflow (cpu) requirement
$ pip install 'kfac[tensorflow]'
```
## KFAC DOCS
Please check [KFAC docs][kfac_docs] for a detailed description with examples
of how to use KFAC. Check the [Keras KFAC docs][keras_docs] for information on
using KFAC with Keras.
[kfac_docs]: https://github.com/tensorflow/kfac/tree/master/docs/index.md
[keras_docs]: https://github.com/tensorflow/kfac/tree/master/kfac/python/keras/README.md
================================================
FILE: docs/applications.md
================================================
Coming soon.
================================================
FILE: docs/contact.md
================================================
Topic | Contact
-------------------- | ---------------------
**Questions** | kfac-users@google.com
**Development Team** | kfac-dev@google.com
Primary contacts:
* James Martens (jamesmartens@google.com)
Contributors (past and present):
* Alok Aggarwal (aloka@google.com)
* Daniel Duckworth (duckworthd@google.com)
* David Pfau (pfau@google.com)
* Dominik Grewe (dominikg@google.com)
* Guodong Zhang (gdzhang@google.com)
* James Keeling (jtkeeling@google.com)
* James Martens (jamesmartens@google.com)
* Jimmy Ba (jba@cs.toronto.edu)
* Lala Li (lala@google.com)
* Matthew Johnson (mattjj@google.com)
* Nicholas Vadivelu (nvadivelu@google.com)
* Noah Siegel (siegeln@google.com)
* Olga Wichrowska (olganw@google.com)
* Rishabh Kabra (rkabra@google.com)
* Roger Grosse (rgrosse@google.com)
* Soham De (sohamde@google.com)
* Tamas Berghammer (tberghammer@google.com)
* Vikram Tankasali (tvikram@google.com)
* Zachary Nado (znado@google.com)
================================================
FILE: docs/examples/auto_damp.md
================================================
# Automatic tuning of damping parameter.
## Table of Contents
* [1. Cached Reader](#1-cached-reader)
* [2. Build optimizer and set damping parameters](#2-build-optimizer-and-set-damping-parameters)
* [TIPS:](#tips)
<br>
The [KFAC damping parameter][kfac_damp] can be tuned automatically using the
Levenberg-Marquardt (LM) algorithm. For a detailed description of the algorithm,
refer to `Section 6` of the [KFAC paper][kfac_paper]. Note that this is still a
heuristic and may not always produce optimal results: depending on the problem,
it can be better or worse than a carefully tuned fixed value.
[kfac_paper]: https://arxiv.org/pdf/1503.05671.pdf
[kfac_damp]: https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md
**Example code**:
https://github.com/tensorflow/kfac/tree/master/kfac/examples/autoencoder_mnist.py
Auto-tuning damping with this method requires some changes to the basic KFAC
training script, which are described below. We highlight only the steps that
differ from training with a fixed damping value (as in the [Convnet
example][convexamplesec]).
[convexamplesec]: https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md
## 1. Cached Reader
Wrap the dataset in a `CachedDataReader`. This makes the previous batch of data
accessible to the optimizer.
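```python
from kfac.python.ops.kfac_utils import data_reader

# Wrap the dataset so the most recently returned batch stays accessible.
cached_reader = data_reader.CachedDataReader(dataset, max_batch_size)
minibatch = cached_reader(batch_size)
```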
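Conceptually, a cached reader hands out batches like any other reader but also remembers the last batch it returned. The LM rule below compares the actual change in the loss with the change predicted by the quadratic model, so the optimizer needs to revisit the batch it just trained on. The plain-Python sketch below illustrates only this caching idea. `SimpleCachedReader` is a hypothetical class, not the `kfac` implementation.

```python
from typing import List, Optional


class SimpleCachedReader:
  """Hypothetical reader that remembers the last batch it handed out."""

  def __init__(self, data: List[int]):
    self._data = data
    self._pos = 0
    self.cached_batch: Optional[List[int]] = None

  def __call__(self, batch_size: int) -> List[int]:
    # Take the next `batch_size` items, wrapping around the dataset.
    batch = [self._data[(self._pos + i) % len(self._data)]
             for i in range(batch_size)]
    self._pos = (self._pos + batch_size) % len(self._data)
    # Keep a copy so the batch can be revisited after the parameters change.
    self.cached_batch = list(batch)
    return batch


reader = SimpleCachedReader(list(range(10)))
batch_t = reader(4)                   # [0, 1, 2, 3]
# ... a parameter update would happen here ...
previous_batch = reader.cached_batch  # still [0, 1, 2, 3]
batch_t1 = reader(4)                  # [4, 5, 6, 7]
```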
## 2. Build optimizer and set damping parameters
```python
import kfac

optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
    learning_rate=1.0,
    damping=150.,  # Initial value; adapted during training.
    momentum=0.95,
    layer_collection=layer_collection,
    batch_size=batch_size,
    adapt_damping=True,
    prev_train_batch=cached_reader.cached_batch,  # From the cached reader.
    is_chief=True,
    loss_fn=loss_fn,
    damping_adaptation_decay=0.95,
    damping_adaptation_interval=FLAGS.damping_adaptation_interval,
)
train_op = optimizer.minimize(loss, global_step=global_step)
```
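The following plain Python gives intuition for `adapt_damping`. It encodes the Levenberg-Marquardt rule from the K-FAC paper as we understand it. First compute the reduction ratio `rho`: the actual change in the loss divided by the change predicted by the local quadratic model. Then decrease damping when `rho > 3/4` and increase it when `rho < 1/4`.

Two things here are our assumptions. We treat `damping_adaptation_decay` as the per-iteration factor and `damping_adaptation_interval` as the number of iterations between adjustments. The function names are hypothetical, and the optimizer's own implementation may differ in detail.

```python
def reduction_ratio(loss_before, loss_after, predicted_change):
  """Actual change in the loss divided by the change the model predicted.

  `predicted_change` is M(delta) - M(0) for the local quadratic model M;
  it is normally negative and must be non-zero.
  """
  return (loss_after - loss_before) / predicted_change


def lm_damping_update(damping, rho, decay=0.95, interval=5):
  """Levenberg-Marquardt style damping adjustment (illustrative sketch)."""
  omega = decay ** interval  # omega < 1 for 0 < decay < 1
  if rho > 0.75:
    # The quadratic model predicted the loss change well: relax damping.
    return damping * omega
  if rho < 0.25:
    # The model was too optimistic: increase damping (shrink trust region).
    return damping / omega
  return damping


rho = reduction_ratio(loss_before=2.0, loss_after=1.6, predicted_change=-0.5)
new_damping = lm_damping_update(150.0, rho)  # rho is about 0.8: damping shrinks
```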
## TIPS:
1. Damping can also be tuned using Population Based Training ([PBT][PBT_link]).
In our observations PBT performs on par with auto-tuning via the LM algorithm,
although it is naturally more computationally expensive. However, if you are
already using PBT for other hyperparameters, consider tuning damping with PBT
as well.
[PBT_link]: https://arxiv.org/abs/1711.09846
================================================
FILE: docs/examples/convolutional.md
================================================
# Convolutional
## Table of Contents
* [Build the Model](#build-the-model)
* [Register the layers and loss](#register-the-layers-and-loss)
* [Build the optimizer](#build-the-optimizer)
* [Fit the model](#fit-the-model)
* [TIPS](#tips)
<br>
K-FAC needs to know about the structure of your model in order to effectively
optimize it. In particular, it needs to know about:
1. Each convolutional and feed forward layer's inputs and outputs.
1. All of the model parameters.
1. The type of the loss function and its inputs.
Let's explore how we can use K-FAC to solve digit classification with MNIST
using a simple convolutional model. In the following example we illustrate how
to use `PeriodicInvCovUpdateKfacOpt`, a subclass of `KfacOptimizer`.
`PeriodicInvCovUpdateKfacOpt` handles the placement and execution of the
covariance and inverse ops. We also illustrate how to register the layers both
manually and automatically using the graph scanner.
**Code**:
https://github.com/tensorflow/kfac/tree/master/kfac/examples/convnet_mnist_single_main.py
## Build the Model
First, we define a model. In this case, we'll load MNIST and construct a 5-layer
ConvNet. The model has 2 Conv/MaxPool pairs and a final linear layer. If we
register the layers manually, we need to keep each layer's inputs, outputs, and
parameters (weights and biases) around, as illustrated here.
```python
import numpy as np
import tensorflow as tf

from kfac.examples import mnist

# conv_layer, max_pool_layer and linear_layer are helper functions defined in
# the example code linked above.

# Load a dataset.
examples, labels = mnist.load_mnist(
    data_dir,
    num_epochs=num_epochs,
    batch_size=128,
    use_fake_data=use_fake_data,
    flatten_images=False)

# Build a ConvNet.
pre0, act0, params0 = conv_layer(
    layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
pre2, act2, params2 = conv_layer(
    layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
logits, params4 = linear_layer(
    layer_id=4, inputs=flat_act3, output_size=num_labels)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits))
accuracy = tf.reduce_mean(
    tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
```
## Register the layers and loss
`layer_collection.auto_register_layers` automatically registers all the layers
for typical/standard models. However, the loss function must still be registered
manually. In the case of cross-entropy loss functions on softmaxes, this amounts
to calling `layer_collection.register_softmax_cross_entropy_loss` with the
logits as an argument. Note that the inputs/outputs of non-parameterized layers
such as max pooling and reshaping _do not_ need to be registered.
```python
# Register parameters with graph_search.
tf.logging.info("Building KFAC Optimizer.")
layer_collection = kfac.LayerCollection()
layer_collection.register_softmax_cross_entropy_loss(logits)
# Set the layer at params0 to use a diagonal approximation
# instead of default Kronecker factor based approximation.
layer_collection.define_linked_parameters(
    params0, approximation=layer_collection.APPROX_DIAGONAL_NAME)
layer_collection.auto_register_layers()
```
In the example above we demonstrate how to use a non-default Fisher
approximation (diagonal) for one of the conv layers. (The default is usually
Kronecker-factored.) This is done by calling
`layer_collection.define_linked_parameters`, which identifies the given
variables as being part of a particular layer, and sets the approximation that
is to be used for that layer. Any registrations performed later, whether done by
the graph scanner or performed manually by the user, will use this approximation
(unless overridden by the `approx` argument to the registration function).
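The precedence rule described above fits in a few lines. An explicit `approx` argument wins. Otherwise the approximation set by `define_linked_parameters` applies. Otherwise the layer type's default is used.

This is a hypothetical sketch. `resolve_approx`, the string labels, and the dictionary are illustrative stand-ins, not `LayerCollection` internals.

```python
from typing import Dict, Optional, Tuple

Params = Tuple[str, ...]


def resolve_approx(params: Params,
                   linked: Dict[Params, str],
                   approx: Optional[str] = None,
                   default: str = "kron") -> str:
  """Hypothetical precedence rule for choosing a layer's Fisher approximation."""
  if approx is not None:
    return approx          # explicit `approx` argument at registration wins
  if params in linked:
    return linked[params]  # set earlier with define_linked_parameters
  return default           # layer-type default (often Kronecker-factored)


linked = {("conv0/kernel", "conv0/bias"): "diagonal"}
```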
Layers can also be registered manually. This is required for types of layers
that the automatic graph scanner doesn't recognize.
Note that one can also combine manual and automatic registration by calling
`auto_register_layers()` after performing some manual registration. The scanner
skips any layers that were already registered manually. We register each layer's
inputs, outputs, and parameters with an instance of `LayerCollection`. For
convolution layers, we use `register_conv2d`, and for fully connected (or linear)
layers, `register_fully_connected`.
```python
# Register parameters manually.
tf.logging.info("Building KFAC Optimizer.")
layer_collection = kfac.LayerCollection()
layer_collection.register_softmax_cross_entropy_loss(logits)
layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples, pre0,
                                 approx=kfac_ff.APPROX_DIAGONAL_NAME)
layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
layer_collection.register_fully_connected(params4, flat_act3, logits)
```
In this example we demonstrate how to use a non-default Fisher approximation
(diagonal) for one of the layers. (The default is usually Kronecker-factored.)
This is done by passing `approx=kfac_ff.APPROX_DIAGONAL_NAME` to the
registration function `layer_collection.register_conv2d`. Note that if one has
already used `define_linked_parameters` to set the approximation, it is not
necessary to specify it again via the `approx` argument.
## Build the optimizer
Finally, we instantiate the optimizer. In addition to `learning_rate` and
`momentum`, the optimizer has two additional important hyperparameters:
1. `cov_ema_decay`: Check the [hyperparameters][hyper_params] section for more
details.
1. `damping`: This is a critical parameter and needs to be tuned for good
performance. Check the [hyperparameters][hyper_params] section for more
details.
[hyper_params]:
https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md
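Here is what `cov_ema_decay` controls. K-FAC maintains exponential moving averages of per-layer covariance statistics. The decay sets how much weight the running estimate keeps relative to the statistics from the newest minibatch.

The NumPy sketch below shows one such update for the input-activation factor E[a aᵀ] of a fully connected layer. `update_cov_ema` is a hypothetical name. The library's factors also handle details this sketch omits, such as bias terms and convolution patches.

```python
import numpy as np


def update_cov_ema(cov, activations, decay=0.95):
  """One moving-average step for an input-covariance factor E[a a^T].

  cov: current [d, d] estimate.
  activations: [batch, d] layer inputs from the current minibatch.
  """
  batch_cov = activations.T @ activations / activations.shape[0]
  return decay * cov + (1.0 - decay) * batch_cov


cov = np.zeros((2, 2))
acts = np.array([[1.0, 0.0], [0.0, 2.0]])
cov = update_cov_ema(cov, acts, decay=0.9)  # 0.1 * [[0.5, 0], [0, 2]]
```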
```python
# Train with K-FAC.
global_step = tf.train.get_or_create_global_step()
optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
    learning_rate=0.0001,
    damping=0.001,
    momentum=0.9,
    cov_ema_decay=0.95,
    invert_every=10,
    cov_update_every=1,
    layer_collection=layer_collection)
train_op = optimizer.minimize(loss, global_step=global_step)
```
## Fit the model
Optimizing with KFAC is similar to using a standard optimizer, where there is an
"update op" that computes and applies the update to the model's parameters.
However, KFAC introduces two additional sets of ops that must also be executed
as part of the algorithm (although not necessarily at every iteration). These
are called the "covariance update ops" and "inverse update ops", respectively.
The covariance update ops update the various "covariance" matrices used to
compute the Fisher block approximations for the layers. The inverse update ops
meanwhile are responsible for computing inverses of the approximate Fisher
blocks (using algorithms that exploit their special structure).
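The special structure works as follows. For a fully connected layer, K-FAC approximates the Fisher block by a Kronecker product of a small input-activation covariance `A` and a small output-gradient covariance `G`. Applying the inverse of `A ⊗ G` then only requires working with `A` and `G` separately, via the identity (A ⊗ G)⁻¹ vec(V) = vec(G⁻¹ V A⁻¹) for symmetric factors. Here vec stacks columns.

The NumPy sketch below checks this identity on random matrices. It ignores damping and is not the library's inversion code.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_spd(n):
  """Random symmetric positive-definite matrix (stand-in for a factor)."""
  m = rng.standard_normal((n, n))
  return m @ m.T + n * np.eye(n)


def kron_solve(A, G, V):
  """Compute G^{-1} V A^{-1}, i.e. (A kron G)^{-1} vec(V), for symmetric A.

  Only the small factors A (in x in) and G (out x out) are solved against.
  """
  left = np.linalg.solve(G, V)          # G^{-1} V
  return np.linalg.solve(A, left.T).T   # (G^{-1} V) A^{-1}, using A = A^T


A = random_spd(3)                # input-side factor
G = random_spd(2)                # output-side factor
V = rng.standard_normal((2, 3))  # a gradient shaped [out, in]

fast = kron_solve(A, G, V)
# Reference: solve with the full 6x6 Kronecker product, using the
# column-stacking vec convention (A kron G) vec(X) = vec(G X A^T).
vec_v = V.reshape(-1, order="F")
slow = np.linalg.solve(np.kron(A, G), vec_v).reshape(2, 3, order="F")
```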
`PeriodicInvCovUpdateKfacOpt`, which is a subclass of the `KfacOptimizer` class,
folds these extra ops into the standard update op, so that they execute
periodically on certain iterations, according to the `cov_update_every` and
`invert_every` arguments. Users seeking more fine-grained control of the timing
and placement of the ops can use the base `KfacOptimizer` class.
```python
with tf.train.MonitoredTrainingSession() as sess:
  while not sess.should_stop():
    # sess.run returns one value per fetch: four fetches, four results.
    global_step_, loss_, accuracy_, _ = sess.run(
        [global_step, loss, accuracy, train_op])
```
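The periodic timing can be pictured as a simple modulo schedule. The sketch below makes two assumptions. First, an update fires when a step counter is divisible by `cov_update_every` or `invert_every`. Second, covariance updates run before inversions, and inversions run before the parameter update.

`ops_for_step` is hypothetical. The optimizer's exact counter and ordering may differ.

```python
def ops_for_step(step, cov_update_every=1, invert_every=10):
  """Which K-FAC ops a periodic schedule would run at a given step (sketch)."""
  ops = []
  if step % cov_update_every == 0:
    ops.append("cov_update")   # refresh the covariance estimates
  if step % invert_every == 0:
    ops.append("inv_update")   # recompute the approximate Fisher inverses
  ops.append("param_update")   # the usual preconditioned gradient step
  return ops


schedule = [ops_for_step(s) for s in range(12)]
```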
## TIPS
1. Check the [hyper params tuning][hp_tune] section for more details on tuning
various KFAC parameters.
[hp_tune]: https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md
[mlp]: https://en.wikipedia.org/wiki/Multilayer_perceptron
[preconditioner]: https://en.wikipedia.org/wiki/Preconditioner#Preconditioning_in_optimization
================================================
FILE: docs/examples/distributed_training.md
================================================
# Distributed Training
## Table of Contents
* [Build the Model](#build-the-model)
* [Register the layers](#register-the-layers)
* [Build the optimizer](#build-the-optimizer)
* [Fit the model](#fit-the-model)
* [TIPS](#tips)
<br>
This example showcases how to use K-FAC in a distributed setting using the
`SyncReplicasOptimizer`. If you are interested in using
`tf.distribute.Strategy`, we support `MirroredStrategy` and `TPUStrategy`, with
an example for `TPUStrategy` [here][tpu_strategy_example]. While most methods
benefit from increased compute, K-FAC particularly shines as the number of
workers (and, in turn, the batch size) increases.
[tpu_strategy_example]: https://github.com/tensorflow/kfac/tree/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb
**Note:** This tutorial extends the single-machine
[Convolutional example][conv_ex] to distributed training. We highly recommend
reading that first, as the shared parts are omitted below.
[conv_ex]:
https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md
**Note:** This tutorial expects you to be familiar with distributed training.
Check out https://www.tensorflow.org/deploy/distributed if this is new to you.
**Example code**:
https://github.com/tensorflow/kfac/tree/master/kfac/examples/convnet_mnist_distributed_main.py
## Build the Model
When training on a single machine, one doesn't need to think about which
"device" a variable is placed on (there's only 1 to choose from!). In a
distributed setting, variables live on ["Parameter Servers"][parameter-servers].
Placing a variable on a parameter server is as simple as using
`tf.train.replica_device_setter()`, as illustrated in the code below.
```python
with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
  pre0, act0, params0 = conv_layer(
      layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
  act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
  ...
```
[parameter-servers]: https://www.tensorflow.org/deploy/distributed
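By default, `tf.train.replica_device_setter` assigns variables to parameter-server tasks round-robin, in the order the variables are created. Other placement strategies can be passed via its `ps_strategy` argument.

The plain-Python sketch below only mimics that default placement, to show how a model's variables end up spread across tasks. It does not call TensorFlow. `round_robin_placement` and the variable names are hypothetical.

```python
def round_robin_placement(var_names, num_ps_tasks, ps_device="/job:ps"):
  """Mimic round-robin assignment of variables to parameter-server tasks."""
  return {name: "%s/task:%d" % (ps_device, i % num_ps_tasks)
          for i, name in enumerate(var_names)}


placement = round_robin_placement(
    ["conv0/kernel", "conv0/bias", "conv2/kernel", "conv2/bias"], 3)
```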
## Register the layers
Layer registration is identical to the single-machine case. See ["Register the
layers"][register-layers-conv] in the Convolutional example for details.
[register-layers-conv]: https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md?#register-the-layers-and-loss
## Build the optimizer
Like the model itself, the K-FAC optimizer also creates variables. Don't forget
to wrap it in a similar `replica_device_setter()` too!
```python
with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
  ...
  optimizer = opt.KfacOptimizer(
      learning_rate=0.0001,
      cov_ema_decay=0.95,
      damping=0.001,
      layer_collection=layer_collection,
      momentum=0.9)
  ...
```
## Fit the model
When training on a single machine, a single training loop is responsible for
executing all of K-FAC's training operations: updating weights, updating
statistics, and inverting the preconditioner matrix. Because all of the work
happens on one machine, there is little to gain from parallelization.
There are different strategies for parallelizing the gradient, covariance and
inverse computations across workers in a distributed setting. Here we
illustrate two such strategies that work with the `SyncReplicasOptimizer`.
The first strategy is to compute gradients in a distributed fashion across all
workers, but execute the inverse and covariance ops only on the chief worker.
**Code**:
https://github.com/tensorflow/kfac/tree/master/kfac/examples/convnet.py
```python
optimizer = opt.KfacOptimizer(...)
sync_optimizer = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
(cov_update_thunks,
 inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

tf.logging.info("Starting training.")
hooks = [sync_optimizer.make_session_run_hook(is_chief)]

def make_update_op(update_thunks):
  update_ops = [thunk() for thunk in update_thunks]
  return tf.group(*update_ops)

if is_chief:
  # Chief: update covariances, periodically invert, then take the step.
  cov_update_op = make_update_op(cov_update_thunks)
  with tf.control_dependencies([cov_update_op]):
    inverse_op = tf.cond(
        tf.equal(tf.mod(global_step, invert_every), 0),
        lambda: make_update_op(inv_update_thunks),
        tf.no_op)
    with tf.control_dependencies([inverse_op]):
      train_op = sync_optimizer.minimize(loss, global_step=global_step)
else:
  # Other workers only compute gradients.
  train_op = sync_optimizer.minimize(loss, global_step=global_step)
```
In the second strategy, each worker's training loop is responsible for executing
only one of K-FAC's three training ops:
1. Workers computing gradients aggregate them synchronously through the
`SyncReplicasOptimizer`.
1. Workers updating covariance matrices can asynchronously update the moving
average, similar to the way asynchronous SGD updates weights.
1. Workers inverting the preconditioning matrix can independently and
asynchronously invert its blocks, one at a time. Blocks are chosen according
to a randomly shuffled queue.
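The randomly shuffled queue in the last item can be pictured as repeated passes over the inversion ops. Each pass hands out every op exactly once, in a fresh random order.

`ShuffledOpQueue` is a hypothetical pure-Python stand-in. The `oq.OpQueue` used below runs inside a TensorFlow session, and its exact shuffling behaviour may differ.

```python
import random


class ShuffledOpQueue:
  """Hypothetical queue: yields every op once per pass, reshuffling each pass."""

  def __init__(self, ops, seed=None):
    self._ops = list(ops)
    self._rng = random.Random(seed)
    self._pending = []

  def next_op(self):
    if not self._pending:
      # Start a new pass with a fresh random order.
      self._pending = self._ops[:]
      self._rng.shuffle(self._pending)
    return self._pending.pop()


queue = ShuffledOpQueue(["inv_block_0", "inv_block_1", "inv_block_2"], seed=0)
first_pass = [queue.next_op() for _ in range(3)]
```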
```python
optimizer = opt.KfacOptimizer(...)
inv_update_queue = oq.OpQueue(optimizer.inv_updates_dict.values())
sync_optimizer = tf.train.SyncReplicasOptimizer(
    opt=optimizer,
    replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks))
train_op = sync_optimizer.minimize(loss, global_step=global_step)

with tf.train.MonitoredTrainingSession(...) as sess:
  while not sess.should_stop():
    if _is_gradient_task(task_id, num_worker_tasks):
      learning_op = train_op
    elif _is_cov_update_task(task_id, num_worker_tasks):
      learning_op = optimizer.cov_update_op
    elif _is_inv_update_task(task_id, num_worker_tasks):
      learning_op = inv_update_queue.next_op(sess)
    global_step_, loss_, statistics_, _ = sess.run(
        [global_step, loss, statistics, learning_op])
```
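The snippet above relies on helpers that are not shown: `_num_gradient_tasks`, `_is_gradient_task`, `_is_cov_update_task` and `_is_inv_update_task`. Below is one hypothetical way to split workers into the three roles. Roughly 60% of workers compute gradients, the last worker runs inversions, and the rest update covariances, with at least one worker per role.

The percentage and the function names are illustrative choices. `replicas_to_aggregate` should match the number of gradient workers, since only they run `train_op`.

```python
def num_gradient_tasks(num_workers, grad_percent=60):
  """Hypothetical: about grad_percent% of workers compute gradients."""
  if num_workers < 3:
    raise ValueError("This strategy needs at least 3 workers (one per role).")
  wanted = (num_workers * grad_percent + 99) // 100  # integer ceiling
  return min(wanted, num_workers - 2)  # keep >= 1 covariance and 1 inverse worker


def task_role(task_id, num_workers, grad_percent=60):
  """Hypothetical role assignment for worker `task_id`."""
  n_grad = num_gradient_tasks(num_workers, grad_percent)
  if task_id < n_grad:
    return "gradient"
  if task_id == num_workers - 1:
    return "inverse"
  return "covariance"


roles = [task_role(i, 10) for i in range(10)]
```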
## TIPS
1. Check the [hyper params tuning][hp_tune] section for more details on tuning
various KFAC parameters.
[hp_tune]: https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md
================================================
FILE: docs/examples/parameters.md
================================================
# K-FAC Parameters
## Table of Contents
* [Damping](#damping)
* [Learning Rate](#learning-rate)
* [Subsample covariance computation](#subsample-covariance-computation)
* [KFAC norm constraint](#kfac-norm-constraint)
* [Covariance decay](#covariance-decay)
* [Train batch size](#train-batch-size)
<br>
We list below various parameters which can be tuned to improve training and run
time performance of K-FAC.
## Damping
Damping is a crucial aspect of K-FAC, as it is for any second-order
optimization or natural gradient method. Broadly speaking, it refers to the
practice of penalizing or constraining the size of the update in various ways so
that it doesn't leave the local region where the quadratic approximation to the
objective (which is used to compute the update) remains accurate. This region is
commonly referred to as the "trust region". In some literature damping is called
"regularization", although we will avoid that term due to its related but
distinct meaning as a method to combat overfitting.
The damping strategy used in KFAC is to (approximately) add a multiple of the
identity to the Fisher before inverting it. This is essentially equivalent to
enforcing that the update lie in a spherical trust region centered at the
current location in parameter space.
The `damping` parameter represents the multiple of identity which is used.
Higher values correspond to smaller trust regions, although the precise
relationship between `damping` and the size of the trust region depends on the
scale of the objective, and will vary from iteration to iteration. (If the loss
function is multiplied by scalar 'alpha' then damping should be multiplied by
'alpha' as well.) Higher values of `damping` can allow higher learning rates,
but as damping tends to infinity the KFAC updates will start to resemble regular
gradient descent updates (scaled by `1/damping`).
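Both claims above can be checked numerically on a tiny problem. The sketch below is an illustration under stated assumptions: it uses a small dense, random positive semi-definite matrix as a stand-in for the Fisher and adds `damping * I` exactly, whereas K-FAC works with a Kronecker-factored approximation and adds the damping only approximately.

```python
import numpy as np


def damped_update(fisher, grad, damping):
  """Preconditioned gradient (F + damping * I)^{-1} g, computed with a linear solve."""
  return np.linalg.solve(fisher + damping * np.eye(fisher.shape[0]), grad)


rng = np.random.default_rng(0)
a = rng.standard_normal((5, 5))
fisher = a @ a.T  # symmetric PSD stand-in for the Fisher
grad = rng.standard_normal(5)

# Very large damping: the update approaches grad / damping (scaled gradient descent).
big = 1e6
near_sgd = np.allclose(big * damped_update(fisher, grad, big), grad, atol=1e-3)

# Multiplying the loss by alpha multiplies F and g by alpha; multiplying the
# damping by alpha as well leaves the update unchanged.
alpha = 10.0
same_update = np.allclose(
    damped_update(fisher, grad, 0.1),
    damped_update(alpha * fisher, alpha * grad, alpha * 0.1))
print(near_sgd, same_update)  # True True
```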
Because `damping` depends on the scale of the loss function, it is a critical
parameter that needs to be tuned. Options for tuning include a grid sweep (which
must be done jointly with the learning rate sweep, not independently) or
automatic tuning with the Levenberg-Marquardt (LM) algorithm (see the [`Auto
Damping`][auto_damping] section for further details). For grid sweeps, a typical
range to consider is logarithmically spaced values from `1e-5` to `100`,
although in principle the optimal value could be any non-negative real number
(because the scale of the loss is arbitrary). Another option for tuning
`damping` is [`Population based training`][PBT] (PBT).
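A joint sweep means every damping value is paired with every learning rate, rather than fixing one while tuning the other. A minimal sketch of such a grid, assuming one log-spaced value per decade from `1e-5` to `100` for both parameters: the dictionary keys are named after the `damping` and `learning_rate` arguments that `kfac.PeriodicInvCovUpdateKfacOpt` takes in the autoencoder example, and training and scoring each trial is deliberately left out.

```python
import itertools

import numpy as np


def joint_grid(low_exp=-5, high_exp=2):
  """All (damping, learning_rate) pairs over log-spaced values, one per decade."""
  values = np.logspace(low_exp, high_exp, num=high_exp - low_exp + 1)
  return [{"damping": float(d), "learning_rate": float(lr)}
          for d, lr in itertools.product(values, values)]


configs = joint_grid()
print(len(configs))  # 64 trials: 8 damping values x 8 learning rates
```

Each dictionary would be unpacked into the optimizer constructor for one training run; PBT or LM-based auto damping replace this exhaustive sweep when it becomes too expensive.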
Refer to section `6` of the [KFAC paper][kfac_paper] for a more detailed
discussion of damping and how it can be used and tuned in K-FAC.
[auto_damping]:
https://github.com/tensorflow/kfac/tree/master/docs/examples/auto_damp.md
[PBT]:
https://arxiv.org/abs/1711.09846
[kfac_paper]:
https://arxiv.org/pdf/1503.05671.pdf
## Learning Rate
Typically sweep over values in the range 1e-5 to 100. It is important to tune
the learning rate in conjunction with damping, since the two are closely coupled
(higher damping allows higher learning rates). The learning rate can also be
tuned using PBT. Note that the optimal learning rate will generally differ from
the learning rate used with the SGD/RMSProp/Adam optimizers.
## Subsample covariance computation
If you are using Conv layers and observe that K-FAC iterations are
significantly slower than Adam's, or if you run out of memory, a possible
remedy is to use subsampling in the covariance computation. To turn on
subsampling, set `kfac_ff.sub_sample_inputs` to `True` and
`kfac_ff.sub_sample_outer_products` to `True`. The former flag subsamples the
batch of inputs used for the covariance computation, and the latter subsamples
the extracted patches based on the size of the covariance matrix. Check the
documentation of `tensorflow_kfac.fisher_factors` for a detailed explanation of
the various subsampling parameters. Also check the [`Distributed
training`][dist_train] section for how to distribute the computation of these
ops over multiple devices.
[dist_train]:
https://github.com/tensorflow/kfac/tree/master/docs/examples/distributed_training.md
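The reason subsampling helps is that a covariance factor built from `n` rows of width `d` costs O(n d^2) to compute, while its shape stays `d x d`. The sketch below illustrates the idea in NumPy by estimating the same second-moment matrix from a random subset of rows. It is an assumption-laden illustration of the concept only; the helper names and the `max_rows` cap are hypothetical and do not reproduce how `fisher_factors` chooses or scales its subsample.

```python
import numpy as np


def input_covariance(x):
  """Second-moment matrix E[x x^T] of layer inputs (rows of x are examples/patches)."""
  return x.T @ x / x.shape[0]


def subsampled_input_covariance(x, max_rows, rng):
  """Estimates the same matrix from at most `max_rows` rows drawn without replacement."""
  n = x.shape[0]
  if n <= max_rows:
    return input_covariance(x)
  idx = rng.choice(n, size=max_rows, replace=False)
  return input_covariance(x[idx])


rng = np.random.default_rng(0)
x = rng.standard_normal((4096, 16))  # e.g. a large batch of extracted patches
full = input_covariance(x)
sub = subsampled_input_covariance(x, max_rows=1024, rng=rng)  # 4x fewer rows
print(full.shape, sub.shape)  # (16, 16) (16, 16)
```

The estimate is noisier than the full-batch one, which is usually tolerable because the moving average described under "Covariance decay" further averages it over many steps.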
## KFAC norm constraint
The norm constraint scales the K-FAC update so that its approximate Fisher norm
is bounded. A typical initial value is 1.0, which can then be tuned using PBT or
a grid search. The norm constraint can be used as an alternative to learning
rate schedules.
See Section 5 of the [Distributed Second-Order Optimization using
Kronecker-Factored Approximations][ba_paper] paper for further details.
[ba_paper]:
https://jimmylba.github.io/papers/nsync.pdf
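One natural way to realize such a bound is sketched below, under stated assumptions: for a step `-lr * F^{-1} g`, its squared Fisher norm is `lr^2 * g^T F^{-1} g`, which can be estimated from the gradient and the preconditioned gradient, and the step is shrunk by the largest factor of at most 1 that keeps this quantity under the constraint. The function name and arguments are illustrative and not the library's API; see the paper above for the exact rule K-FAC uses.

```python
import numpy as np


def clip_coefficient(grad, precon_grad, learning_rate, norm_constraint):
  """Largest factor <= 1 keeping the step's squared Fisher norm <= norm_constraint.

  The squared Fisher norm of the step -lr * F^{-1} g is lr^2 * g^T F^{-1} g,
  estimated here as lr^2 * <g, precon_grad>.
  """
  sq_fisher_norm = learning_rate ** 2 * np.dot(grad, precon_grad)
  if sq_fisher_norm <= 0:
    return 1.0
  return min(1.0, np.sqrt(norm_constraint / sq_fisher_norm))


# Toy example with F = diag(2, 0.5), so precon_grad = F^{-1} g.
grad = np.array([1.0, 1.0])
precon_grad = np.array([0.5, 2.0])
coeff = clip_coefficient(grad, precon_grad, learning_rate=1.0, norm_constraint=1.0)
print(round(coeff, 4))  # 0.6325: the step is shrunk to meet the constraint
```

With a small enough learning rate the coefficient is 1.0 and the update passes through unchanged, which is why the constraint can stand in for a decaying learning rate schedule.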
## Covariance decay
During the course of the algorithm, an exponential moving average tracks
statistics for each layer. Slower decays (decay values closer to 1) mean that
the statistics are based on more data, but they suffer more from staleness
(because the model parameters keep changing). This parameter can usually be left
at its default value but may occasionally matter for some problems. In such
cases, some reasonable values to sweep over are `[0.9, 0.95, 0.99, 0.999]`.
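To make the staleness trade-off concrete, the sketch below applies a standard exponential moving average update to a scalar statistic that jumps from an old value of 1.0 to a new value of 0.0, and counts how many steps each candidate decay needs before the old value's weight drops below 10%. The scalar and the 10% threshold are illustrative assumptions; K-FAC applies the same kind of update to whole covariance matrices.

```python
def ema_update(running, new, decay):
  """One moving-average step: running <- decay * running + (1 - decay) * new."""
  return decay * running + (1.0 - decay) * new


def steps_to_forget(decay, threshold=0.1):
  """Steps until an old statistic's weight in the average falls below `threshold`."""
  running, steps = 1.0, 0  # start fully at the old value 1.0 ...
  while running >= threshold:
    running = ema_update(running, 0.0, decay)  # ... then observe only 0.0
    steps += 1
  return steps


for decay in [0.9, 0.95, 0.99, 0.999]:
  print(decay, steps_to_forget(decay))
```

Higher decays average over more mini-batches but take proportionally longer to discard statistics gathered under old parameter values.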
## Train batch size
Typically try using a larger batch size compared to training with
SGD/RMSprop/Adam.
================================================
FILE: docs/index.md
================================================
# Home
Kronecker factored approximate curvature
**K-FAC in TensorFlow** is an implementation of K-FAC, an approximate
second-order optimization method, in TensorFlow. K-FAC can converge much
faster than SGD or Adam on certain neural network architectures (especially when
using larger batch sizes), but may be closer in performance on other
architectures (such as ResNets).
## Table of Contents
* [What is K-FAC?](#what-is-k-fac)
* [Why should I use K-FAC?](#why-should-i-use-k-fac)
* [How do I use K-FAC?](#how-do-i-use-k-fac)
## What is K-FAC?
K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation
to the [Natural Gradient][natural_gradient] algorithm designed specifically for
neural networks. It maintains an approximation to the [Fisher Information
matrix][fisher_information], whose inverse is used as a preconditioner for
(stochastic) gradient descent.
K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations.
However, it is slightly more restrictive than SGD or Adam, as it makes some
assumptions about the structure of the model and the loss function.
Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What
are the weights for layer i?"). As such, you must add some additional code while
constructing your model to use K-FAC.
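The "Kronecker-factored" part is what makes the preconditioner cheap. As a sketch, assume a single fully connected layer whose Fisher block is approximated by the Kronecker product `A ⊗ G` of an input-activation factor `A` and an output-gradient factor `G`. Then the preconditioned weight gradient can be computed from the two small factors without ever forming the full block. The matrices below are random symmetric positive definite stand-ins, not statistics from a real model.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out = 4, 3


def random_spd(n):
  """Well-conditioned symmetric positive definite stand-in for a K-FAC factor."""
  m = rng.standard_normal((n, n))
  return m @ m.T + n * np.eye(n)


A = random_spd(n_in)    # stand-in for E[a a^T] over layer inputs
G = random_spd(n_out)   # stand-in for E[g g^T] over pre-activation gradients
grad_w = rng.standard_normal((n_in, n_out))  # gradient of the weight matrix W

# Factored: only the small n_in x n_in and n_out x n_out factors are inverted.
precond_kfac = np.linalg.solve(A, grad_w) @ np.linalg.inv(G)

# Reference: build (A kron G) explicitly for row-major flattened W and solve.
F = np.kron(A, G)
precond_dense = np.linalg.solve(F, grad_w.reshape(-1)).reshape(n_in, n_out)

print(np.allclose(precond_kfac, precond_dense))  # True
```

For realistic layer sizes the explicit block would have `(n_in * n_out)^2` entries, while the factors have only `n_in^2 + n_out^2`, which is why K-FAC stores and inverts the factors instead.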
[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746
[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form
## Why should I use K-FAC?
K-FAC can take advantage of the curvature of the optimization problem, resulting
in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same
loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. See
reference code [here][autoencoder-code] and plots comparing KFAC with SGD below.

[autoencoder-code]: https://github.com/tensorflow/kfac/tree/master/kfac/examples/autoencoder_mnist.py
## How do I use K-FAC?
Using K-FAC requires three steps:
1. Register layer inputs, weights, and pre-activations with a
`kfac.LayerCollection`.
2. Register loss functions.
3. Minimize the loss with a `kfac.PeriodicInvCovUpdateKfacOpt`.
```python
import kfac
import tensorflow.compat.v1 as tf

# Build model.
w = tf.get_variable("w", ...)
b = tf.get_variable("b", ...)
logits = tf.matmul(x, w) + b
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Register loss.
layer_collection = kfac.LayerCollection()
layer_collection.register_softmax_cross_entropy_loss(logits)

# Register layers.
layer_collection.auto_register_layers()

# Construct training ops.
optimizer = kfac.PeriodicInvCovUpdateKfacOpt(..., layer_collection=layer_collection)
train_op = optimizer.minimize(loss)

# Minimize loss.
with tf.Session() as sess:
  ...
  sess.run([train_op])
```
Check out the Convnet training [example][convexamplesec] for more details. Also
check the [`PeriodicInvCovUpdateKfacOpt`][periodicincovupdate] optimizer to see
how the placement and execution of the covariance and inverse ops can be
handled automatically.
[convexamplesec]: https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md
[periodicincovupdate]: https://github.com/tensorflow/kfac/tree/master/kfac/python/ops/kfac_utils/periodic_inv_cov_update_kfac_opt.py
## Table of contents
* [Home](https://github.com/tensorflow/kfac/tree/master/docs/index.md)
* User Guide
* [Keras](https://github.com/tensorflow/kfac/tree/master/kfac/python/keras/README.md)
* [Convolutional](https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md)
* [Auto damping](https://github.com/tensorflow/kfac/tree/master/docs/examples/auto_damp.md)
* [Distributed Training](https://github.com/tensorflow/kfac/tree/master/docs/examples/distributed_training.md)
* [Parameters](https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md)
* [Applications](https://github.com/tensorflow/kfac/tree/master/docs/applications.md)
* [Some KFAC-Papers](https://github.com/tensorflow/kfac/tree/master/docs/papers.md)
* [Contact](https://github.com/tensorflow/kfac/tree/master/docs/contact.md)
================================================
FILE: docs/papers.md
================================================
* Martens, James, and Roger Grosse.
["Optimizing neural networks with Kronecker-factored approximate curvature."][kfac_paper]
International Conference on Machine Learning. 2015.
* Grosse, Roger, and James Martens.
["A Kronecker-factored approximate Fisher matrix for convolution layers."][kfac_conv_paper]
International Conference on Machine Learning. 2016.
* Ba, Jimmy, Roger Grosse, and James Martens. ["Distributed Second-Order
Optimization using Kronecker-Factored Approximations."][distributed_kfac]
(2016).
* Martens, James, Jimmy Ba, and Matt Johnson. ["Kronecker-factored Curvature
Approximations for Recurrent Neural Networks."][kfac_rnn_paper]
International Conference on Learning Representations. 2018.
[kfac_paper]: https://arxiv.org/abs/1503.05671
[kfac_conv_paper]: https://arxiv.org/abs/1602.01407
[kfac_rnn_paper]: https://openreview.net/forum?id=HyMTkQZAb
[distributed_kfac]: https://openreview.net/forum?id=SkkTMpjex
================================================
FILE: docs/sitemap.md
================================================
* [Home](https://github.com/tensorflow/kfac/tree/master/docs/index.md)
* User Guide
* [Convolutional](https://github.com/tensorflow/kfac/tree/master/docs/examples/convolutional.md)
* [Auto damping](https://github.com/tensorflow/kfac/tree/master/docs/examples/auto_damp.md)
* [Distributed Training](https://github.com/tensorflow/kfac/tree/master/docs/examples/distributed_training.md)
* [Parameters](https://github.com/tensorflow/kfac/tree/master/docs/examples/parameters.md)
* [Applications](https://github.com/tensorflow/kfac/tree/master/docs/applications.md)
* [Some KFAC-Papers](https://github.com/tensorflow/kfac/tree/master/docs/papers.md)
* [Contact](https://github.com/tensorflow/kfac/tree/master/docs/contact.md)
================================================
FILE: kfac/__init__.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Kronecker-factored Approximate Curvature Optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long
from kfac.python import keras
from kfac.python.ops import curvature_matrix_vector_products
from kfac.python.ops import estimator
from kfac.python.ops import fisher_blocks
from kfac.python.ops import fisher_factors
from kfac.python.ops import layer_collection
from kfac.python.ops import linear_operator
from kfac.python.ops import loss_functions
from kfac.python.ops import op_queue
from kfac.python.ops import optimizer
from kfac.python.ops import placement
from kfac.python.ops import utils
from kfac.python.ops.kfac_utils import async_inv_cov_update_kfac_opt
from kfac.python.ops.kfac_utils import data_reader
from kfac.python.ops.kfac_utils import data_reader_alt
from kfac.python.ops.kfac_utils import periodic_inv_cov_update_kfac_opt
from kfac.python.ops.tensormatch import graph_matcher
from kfac.python.ops.tensormatch import graph_search
# pylint: enable=unused-import
# pylint: disable=invalid-name
LayerCollection = layer_collection.LayerCollection
KfacOptimizer = optimizer.KfacOptimizer
PeriodicInvCovUpdateKfacOpt = periodic_inv_cov_update_kfac_opt.PeriodicInvCovUpdateKfacOpt
AsyncInvCovUpdateKfacOpt = async_inv_cov_update_kfac_opt.AsyncInvCovUpdateKfacOpt
CurvatureMatrixVectorProductComputer = curvature_matrix_vector_products.CurvatureMatrixVectorProductComputer
# pylint: enable=invalid-name, line-too-long
================================================
FILE: kfac/examples/__init__.py
================================================
================================================
FILE: kfac/examples/autoencoder_mnist.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Full implementation of deep autoencoder experiment from original K-FAC paper.
This script demonstrates training using KFAC optimizer, updating the damping
parameter according to the Levenberg-Marquardt rule, and using the quadratic
model method for adapting the learning rate and momentum parameters.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
# Dependency imports
from absl import flags
import kfac
import sonnet as snt
import tensorflow.compat.v1 as tf
from kfac.examples import mnist
from kfac.python.ops.kfac_utils import data_reader
from kfac.python.ops.kfac_utils import data_reader_alt
# Model parameters
_ENCODER_SIZES = [1000, 500, 250, 30]
_DECODER_SIZES = [250, 500, 1000]
_NONLINEARITY = tf.tanh # Note: sigmoid cannot be used with the default init.
_WEIGHTS_INITIALIZER = None # Default init
flags.DEFINE_integer('train_steps', 10000, 'Number of training steps.')
flags.DEFINE_integer('inverse_update_period', 5,
'# of steps between computing inverse of Fisher factor '
'matrices.')
flags.DEFINE_integer('cov_update_period', 1,
'# of steps between computing covariance matrices.')
flags.DEFINE_integer('damping_adaptation_interval', 5,
'# of steps between updating the damping parameter.')
flags.DEFINE_integer('num_burnin_steps', 5, 'Number of steps at the '
'start of training where the optimizer will only perform '
'cov updates.')
flags.DEFINE_integer('seed', 12345, 'Random seed')
flags.DEFINE_float('learning_rate', 3e-3,
'Learning rate to use when lrmu_adaptation="off".')
flags.DEFINE_float('momentum', 0.9,
'Momentum decay value to use when '
'lrmu_adaptation="off" or "only_lr".')
flags.DEFINE_float('damping', 1e-2, 'The fixed damping value to use. This is '
'ignored if adapt_damping is True.')
flags.DEFINE_float('l2_reg', 1e-5,
'L2 regularization applied to weight matrices.')
flags.DEFINE_boolean('update_damping_immediately', True, 'Adapt the damping '
'immediately after the parameter update (i.e. in the same '
'sess.run() call). Only safe if everything is a resource '
'variable.')
flags.DEFINE_boolean('use_batch_size_schedule', True,
'If True then we use the growing mini-batch schedule from '
'the original K-FAC paper.')
flags.DEFINE_integer('batch_size', 1024,
'The size of the mini-batches to use if not using the '
'schedule.')
flags.DEFINE_string('lrmu_adaptation', 'on',
'If set to "on" then we use the quadratic model '
'based learning-rate and momentum adaptation method from '
'the original paper. Note that this only works well in '
'practice when use_batch_size_schedule=True. Can also '
'be set to "off" and "only_lr", which turns '
'it off, or uses a version where the momentum parameter '
'is fixed (resp.).')
flags.DEFINE_boolean('use_alt_data_reader', True,
'If True we use the alternative data reader for MNIST '
'that is faster for small datasets.')
flags.DEFINE_string('device', '/gpu:0',
'The device to run the major ops on.')
flags.DEFINE_boolean('adapt_damping', True,
'If True we use the LM rule for damping adaptation as '
'described in the original K-FAC paper.')
# When using damping adaptation it is advisable to start with a high
# value. This value is probably far too high to use for most neural nets
# if you aren't using damping adaptation. (Although it always depends on
# the scale of the loss.)
flags.DEFINE_float('initial_damping', 150.0,
'The initial damping value to use when adapt_damping is '
'True.')
flags.DEFINE_string('optimizer', 'kfac',
'The optimizer to use. Can be kfac or adam. If adam is '
'used the various K-FAC hyperparameters map roughly onto '
'their Adam equivalents.')
flags.DEFINE_boolean('auto_register_layers', True,
'If True we use the automatic registration feature '
'which relies on scanning the TF graph. Otherwise '
'registration is done manually by this script during '
'the construction of the model.')
flags.DEFINE_boolean('use_keras_model', False,
'If True, we use a Keras version of the autoencoder '
'model. Only works when auto_register_layers=True.')
flags.DEFINE_boolean('use_sequential_for_keras', True,
'If True, we construct the Keras model using the '
'Sequential class.')
flags.DEFINE_boolean('use_control_flow_v2', False, 'If True, we use Control '
'Flow V2. Defaults to False.')
FLAGS = flags.FLAGS
def make_train_op(minibatch,
                  batch_size,
                  batch_loss,
                  layer_collection,
                  loss_fn,
                  prev_train_batch=None,
                  placement_strategy=None,
                  print_logs=False,
                  tf_replicator=None):
  """Constructs optimizer and train op.

  Args:
    minibatch: A list/tuple of Tensors (typically representing the current
      mini-batch of input images and labels).
    batch_size: Tensor of shape (). Size of the training mini-batch.
    batch_loss: Tensor of shape (). Mini-batch loss tensor.
    layer_collection: LayerCollection object. Registry for model parameters.
      Required when using a K-FAC optimizer.
    loss_fn: Function which takes as input a mini-batch and returns the loss.
    prev_train_batch: `Tensor` of the previous training batch, can be accessed
      from the data_reader.CachedReader cached_batch property. (Default: None)
    placement_strategy: `str`, the placement_strategy argument for
      `KfacOptimizer`. (Default: None)
    print_logs: `Bool`. If True we print logs using K-FAC's built-in
      tf.print-based logs printer. (Default: False)
    tf_replicator: A Replicator object or None. If not None, K-FAC will set
      itself up to work inside of the provided TF-Replicator object.
      (Default: None)

  Returns:
    train_op: Op that can be used to update model parameters.
    optimizer: Optimizer used to produce train_op.

  Raises:
    ValueError: If layer_collection is None when K-FAC is selected as an
      optimization method.
  """
  global_step = tf.train.get_or_create_global_step()
  if FLAGS.optimizer == 'kfac':
    if FLAGS.lrmu_adaptation == 'on':
      learning_rate = None
      momentum = None
      momentum_type = 'qmodel'
    elif FLAGS.lrmu_adaptation == 'only_lr':
      learning_rate = None
      momentum = FLAGS.momentum
      momentum_type = 'qmodel_fixedmu'
    elif FLAGS.lrmu_adaptation == 'off':
      learning_rate = FLAGS.learning_rate
      momentum = FLAGS.momentum
      # momentum_type = 'regular'
      momentum_type = 'adam'
    if FLAGS.adapt_damping:
      damping = FLAGS.initial_damping
    else:
      damping = FLAGS.damping
    optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
        invert_every=FLAGS.inverse_update_period,
        cov_update_every=FLAGS.cov_update_period,
        learning_rate=learning_rate,
        damping=damping,
        cov_ema_decay=0.95,
        momentum=momentum,
        momentum_type=momentum_type,
        layer_collection=layer_collection,
        batch_size=batch_size,
        num_burnin_steps=FLAGS.num_burnin_steps,
        adapt_damping=FLAGS.adapt_damping,
        l2_reg=FLAGS.l2_reg,
        placement_strategy=placement_strategy,
        print_logs=print_logs,
        tf_replicator=tf_replicator,
        # Note that many of the arguments below don't do anything when
        # adapt_damping=False.
        update_damping_immediately=FLAGS.update_damping_immediately,
        is_chief=True,
        prev_train_batch=prev_train_batch,  # We don't actually need this unless
                                            # update_damping_immediately is
                                            # False.
        loss=batch_loss,
        loss_fn=loss_fn,
        damping_adaptation_decay=0.95,
        damping_adaptation_interval=FLAGS.damping_adaptation_interval,
        min_damping=1e-6,
        train_batch=minibatch,
    )
  elif FLAGS.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.momentum,
        epsilon=FLAGS.damping,
        beta2=0.99)
  return optimizer.minimize(batch_loss, global_step=global_step), optimizer
class AutoEncoder(snt.AbstractModule):
  """Simple autoencoder module."""

  def __init__(self,
               input_size,
               regularizers=None,
               initializers=None,
               custom_getter=None,
               name='AutoEncoder'):
    super(AutoEncoder, self).__init__(custom_getter=custom_getter, name=name)
    if initializers is None:
      initializers = {'w': tf.glorot_uniform_initializer(),
                      'b': tf.zeros_initializer()}
    if regularizers is None:
      regularizers = {'w': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),
                      'b': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),}
    with self._enter_variable_scope():
      self._encoder = snt.nets.MLP(
          output_sizes=_ENCODER_SIZES,
          regularizers=regularizers,
          initializers=initializers,
          custom_getter=custom_getter,
          activation=_NONLINEARITY,
          activate_final=False)
      self._decoder = snt.nets.MLP(
          output_sizes=_DECODER_SIZES + [input_size],
          regularizers=regularizers,
          initializers=initializers,
          custom_getter=custom_getter,
          activation=_NONLINEARITY,
          activate_final=False)

  def _build(self, inputs):
    code = self._encoder(inputs)
    output = self._decoder(code)
    return output
class MLPManualReg(snt.AbstractModule):
  """MLP that can manually register each of its layers with a LayerCollection."""

  def __init__(self,
               output_sizes,
               regularizers=None,
               initializers=None,
               custom_getter=None,
               activation=_NONLINEARITY,
               activate_final=False,
               name='MLP'):
    super(MLPManualReg, self).__init__(custom_getter=custom_getter, name=name)
    self._output_sizes = output_sizes
    self._activation = activation
    self._activate_final = activate_final
    with self._enter_variable_scope():
      self._layers = [snt.Linear(self._output_sizes[i],
                                 name='linear_{}'.format(i),
                                 initializers=initializers,
                                 regularizers=regularizers,
                                 custom_getter=custom_getter,
                                 use_bias=True)
                      for i in range(len(self._output_sizes))]

  def _build(self, inputs, layer_collection=None):
    net = inputs
    for i in range(len(self._output_sizes)):
      layer_inputs = net
      net = self._layers[i](net)
      layer_outputs = net  # Pre-activations, registered before the nonlinearity.
      params = (self._layers[i].w, self._layers[i].b)
      if layer_collection is not None:
        layer_collection.register_fully_connected(params,
                                                  layer_inputs,
                                                  layer_outputs,
                                                  reuse=False)
      if i < len(self._output_sizes) - 1 or self._activate_final:
        net = self._activation(net)
    return net
class AutoEncoderManualReg(snt.AbstractModule):
  """Simple autoencoder module whose layers are registered with K-FAC manually."""

  def __init__(self,
               input_size,
               regularizers=None,
               initializers=None,
               custom_getter=None,
               name='AutoEncoder'):
    super(AutoEncoderManualReg, self).__init__(custom_getter=custom_getter,
                                               name=name)
    if initializers is None:
      initializers = {'w': tf.glorot_uniform_initializer(),
                      'b': tf.zeros_initializer()}
    if regularizers is None:
      regularizers = {'w': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),
                      'b': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),}
    with self._enter_variable_scope():
      self._encoder = MLPManualReg(
          output_sizes=_ENCODER_SIZES,
          regularizers=regularizers,
          initializers=initializers,
          custom_getter=custom_getter,
          activation=_NONLINEARITY,
          activate_final=False)
      self._decoder = MLPManualReg(
          output_sizes=_DECODER_SIZES + [input_size],
          regularizers=regularizers,
          initializers=initializers,
          custom_getter=custom_getter,
          activation=_NONLINEARITY,
          activate_final=False)

  def _build(self, inputs, layer_collection=None):
    code = self._encoder(inputs, layer_collection=layer_collection)
    output = self._decoder(code, layer_collection=layer_collection)
    return output
def get_keras_autoencoder(**input_kwargs):
  """Returns autoencoder made with Keras.

  Args:
    **input_kwargs: Arguments to pass to tf.keras.layers.Input. You must include
      either the 'shape' or 'tensor' kwarg.

  Returns:
    A tf.keras.Model, the Autoencoder.
  """
  layers = tf.keras.layers
  regularizers = tf.keras.regularizers
  dense_kwargs = {
      'kernel_initializer': tf.glorot_uniform_initializer(),
      'bias_initializer': tf.zeros_initializer(),
      'kernel_regularizer': regularizers.l2(l=FLAGS.l2_reg),
      'bias_regularizer': regularizers.l2(l=FLAGS.l2_reg),
  }
  if FLAGS.use_sequential_for_keras:
    model = tf.keras.Sequential()
    # Create Encoder
    model.add(layers.Input(**input_kwargs))
    for size in _ENCODER_SIZES[:-1]:
      model.add(layers.Dense(
          size, activation=_NONLINEARITY, **dense_kwargs))
    model.add(layers.Dense(_ENCODER_SIZES[-1], **dense_kwargs))
    # Create Decoder
    for size in _DECODER_SIZES:
      model.add(layers.Dense(size, activation=_NONLINEARITY, **dense_kwargs))
    model.add(layers.Dense(784, **dense_kwargs))
  else:
    # Make sure you always wrap the input in keras
    inputs = layers.Input(**input_kwargs)
    x = inputs
    # Create Encoder
    for size in _ENCODER_SIZES[:-1]:
      x = layers.Dense(size, activation=_NONLINEARITY, **dense_kwargs)(x)
    x = layers.Dense(_ENCODER_SIZES[-1], **dense_kwargs)(x)
    # Create Decoder
    for size in _DECODER_SIZES:
      x = layers.Dense(size, activation=_NONLINEARITY, **dense_kwargs)(x)
    x = layers.Dense(784, **dense_kwargs)(x)
    model = tf.keras.Model(inputs=inputs, outputs=x)
  return model
def compute_squared_error(logits, targets):
  """Squared reconstruction error: mean over the batch, summed over outputs."""
  return tf.reduce_sum(
      tf.reduce_mean(tf.square(targets - tf.nn.sigmoid(logits)), axis=0))


def compute_loss(logits=None,
                 labels=None,
                 return_error=False,
                 model=None):
  """Compute loss value."""
  if FLAGS.use_keras_model:
    total_regularization_loss = tf.add_n(model.losses)
  else:
    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    total_regularization_loss = tf.add_n(graph_regularizers)
  loss_matrix = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                        labels=labels)
  loss = tf.reduce_sum(tf.reduce_mean(loss_matrix, axis=0))
  regularized_loss = loss + total_regularization_loss
  if return_error:
    squared_error = compute_squared_error(logits, labels)
    return regularized_loss, squared_error
  return regularized_loss
def load_mnist():
  """Creates MNIST dataset and wraps it inside cached data reader.

  Returns:
    cached_reader: `data_reader.CachedReader` instance which wraps MNIST
      dataset.
    num_examples: int. The number of training examples.
  """
  # Wrap the data set into cached_reader, which provides variable-sized training
  # batches and caches the most recently read training batch.
  if not FLAGS.use_alt_data_reader:
    # Version 1 using data_reader.py (slow!)
    dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)
    if FLAGS.use_batch_size_schedule:
      max_batch_size = num_examples
    else:
      max_batch_size = FLAGS.batch_size
    # Shuffle before repeat is correct unless you want repeat cases in the
    # same batch.
    dataset = (dataset.shuffle(num_examples).repeat()
               .batch(max_batch_size).prefetch(5))
    dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
    # This version of CachedDataReader requires the dataset to be shuffled
    return data_reader.CachedDataReader(dataset, max_batch_size), num_examples
  else:
    # Version 2 using data_reader_alt.py (faster)
    images, labels, num_examples = mnist.load_mnist_as_tensors(
        flatten_images=True)
    dataset = (images, labels)
    # This version of CachedDataReader requires the dataset to NOT be shuffled
    return data_reader_alt.CachedDataReader(dataset, num_examples), num_examples
def _get_batch_size_schedule(minibatch_maxsize):
  """Returns training batch size schedule."""
  minibatch_maxsize_targetiter = 500
  minibatch_startsize = 1000
  div = (float(minibatch_maxsize_targetiter-1)
         / math.log(float(minibatch_maxsize)/minibatch_startsize, 2))
  return [
      min(int(2.**(float(k)/div) * minibatch_startsize), minibatch_maxsize)
      for k in range(minibatch_maxsize_targetiter)
  ]
def construct_train_quants():
  """Returns tensors and optimizer required to run the autoencoder."""
  with tf.device(FLAGS.device):
    # Load dataset.
    cached_reader, num_examples = load_mnist()
    batch_size_schedule = _get_batch_size_schedule(num_examples)
    batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')
    minibatch = cached_reader(batch_size)
    features, _ = minibatch
    if FLAGS.auto_register_layers:
      if FLAGS.use_keras_model:
        training_model = get_keras_autoencoder(tensor=features)
      else:
        training_model = AutoEncoder(784)
    else:
      training_model = AutoEncoderManualReg(784)
    layer_collection = kfac.LayerCollection()

    def loss_fn(minibatch, logits=None, return_error=False):
      features, _ = minibatch
      if logits is None:
        logits = training_model(features)
      return compute_loss(
          logits=logits,
          labels=features,
          return_error=return_error,
          model=training_model)

    if FLAGS.use_keras_model:
      logits = training_model.output
    else:
      if FLAGS.auto_register_layers:
        logits = training_model(features)
      else:
        logits = training_model(features, layer_collection=layer_collection)
    (batch_loss, batch_error) = loss_fn(
        minibatch, logits=logits, return_error=True)
    # Make sure never to confuse this with register_softmax_cross_entropy_loss!
    layer_collection.register_sigmoid_cross_entropy_loss(logits,
                                                         seed=FLAGS.seed + 1)
    if FLAGS.auto_register_layers:
      layer_collection.auto_register_layers()
    # Make training op
    train_op, opt = make_train_op(
        minibatch,
        batch_size,
        batch_loss,
        layer_collection,
        loss_fn=loss_fn,
        prev_train_batch=cached_reader.cached_batch)
  return train_op, opt, batch_loss, batch_error, batch_size_schedule, batch_size
def main(_):
# If using update_damping_immediately resource variables must be enabled.
# Would recommend always enabling them anyway.
if FLAGS.update_damping_immediately:
tf.enable_resource_variables()
if FLAGS.use_control_flow_v2:
tf.enable_control_flow_v2()
if not FLAGS.auto_register_layers and FLAGS.use_keras_model:
raise ValueError('Require auto_register_layers=True when using Keras '
'model.')
tf.set_random_seed(FLAGS.seed)
(train_op, opt, batch_loss, batch_error, batch_size_schedule,
batch_size) = construct_train_quants()
global_step = tf.train.get_or_create_global_step()
if FLAGS.optimizer == 'kfac':
# We need to put the control depenency on train_op here so that we are
# guaranteed to get the up-to-date values of these various quantities.
# Otherwise there is a race condition and we might get the old values,
# nondeterministically. Another solution would be to get these values in
# a separate sess.run call, but this can sometimes cause problems with
# training frameworks that use hooks (see the comments below).
with tf.control_dependencies([train_op]):
learning_rate = opt.learning_rate
momentum = opt.momentum
damping = opt.damping
rho = opt.rho
qmodel_change = opt.qmodel_change
# Without setting allow_soft_placement=True there will be problems when
# the optimizer tries to place certain ops like "mod" on the GPU (which isn't
# supported).
config = tf.ConfigProto(allow_soft_placement=True)
# It's good practice to put everything into a single sess.run call. The
# reason is that certain "training frameworks" like to run hooks at each
# sess.run call, and there is an implicit expectation there will only
# be one sess.run call every "iteration" of the "optimizer". For example,
# a framework might try to print the loss at each sess.run call, causing
# the mini-batch to be advanced, thus completely breaking the "cached
# batch" mechanism that the damping adaptation method may rely on. (Plus
# there will also be the extra cost of having to reevaluate the loss
  # twice.) That being said, we don't completely do that here (global_step is
  # fetched in a separate sess.run call) because it's inconvenient.
# Train model.
with tf.train.MonitoredTrainingSession(save_checkpoint_secs=30,
config=config) as sess:
for _ in range(FLAGS.train_steps):
i = sess.run(global_step)
if FLAGS.use_batch_size_schedule:
batch_size_ = batch_size_schedule[min(i, len(batch_size_schedule) - 1)]
else:
batch_size_ = FLAGS.batch_size
if FLAGS.optimizer == 'kfac':
(_, batch_loss_, batch_error_, learning_rate_, momentum_, damping_,
rho_, qmodel_change_) = sess.run([train_op, batch_loss, batch_error,
learning_rate, momentum, damping,
rho, qmodel_change],
feed_dict={batch_size: batch_size_})
else:
_, batch_loss_, batch_error_ = sess.run(
[train_op, batch_loss, batch_error],
feed_dict={batch_size: batch_size_})
# Print training stats.
tf.logging.info(
'iteration: %d', i)
tf.logging.info(
'mini-batch size: %d | mini-batch loss = %f | mini-batch error = %f ',
batch_size_, batch_loss_, batch_error_)
if FLAGS.optimizer == 'kfac':
tf.logging.info(
'learning_rate = %f | momentum = %f',
learning_rate_, momentum_)
tf.logging.info(
'damping = %f | rho = %f | qmodel_change = %f',
damping_, rho_, qmodel_change_)
tf.logging.info('----')
if __name__ == '__main__':
tf.disable_v2_behavior()
tf.app.run(main)
================================================
FILE: kfac/examples/autoencoder_mnist_tpu_estimator.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Deep AutoEncoder from Martens & Grosse (2015).
This script demonstrates training on TPUs with TPU Estimator using the KFAC
optimizer, updating the damping parameter according to the
Levenberg-Marquardt rule, and using the quadratic model method for adapting
the learning rate and momentum parameters.
See third_party/tensorflow_kfac/google/examples/ae_tpu_xm_launcher.py
for an example Borg launch script. If you can't access this launch script,
some important things to know about running K-FAC on TPUs (at least for this
example) are that you must use higher-precision matrix multiplications.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
from absl import flags
import kfac
import tensorflow.compat.v1 as tf
from tensorflow.contrib import tpu as contrib_tpu
from kfac.examples import autoencoder_mnist
from kfac.examples import mnist
flags.DEFINE_integer('save_checkpoints_steps', 500,
'Number of iterations between model checkpoints.')
flags.DEFINE_integer('iterations_per_loop', 100,
'Number of iterations in a TPU training loop.')
flags.DEFINE_string('model_dir', '', 'Model dir.')
flags.DEFINE_string('master', None,
'GRPC URL of the master '
'(e.g. grpc://ip.address.of.tpu:8470).')
FLAGS = flags.FLAGS
def make_train_op(minibatch,
batch_loss,
layer_collection,
loss_fn):
"""Constructs optimizer and train op.
Args:
minibatch: Tuple[Tensor, Tensor] representing the current batch of input
images and labels.
    batch_loss: Tensor of shape (). Loss with respect to minibatch to be
      minimized.
layer_collection: LayerCollection object. Registry for model parameters.
Required when using a K-FAC optimizer.
loss_fn: A function that when called constructs the graph to compute the
model loss on the current minibatch. Returns a Tensor of the loss scalar.
Returns:
train_op: Op that can be used to update model parameters.
optimizer: The KFAC optimizer used to produce train_op.
Raises:
ValueError: If layer_collection is None when K-FAC is selected as an
optimization method.
"""
# Do not use CrossShardOptimizer with K-FAC. K-FAC now handles its own
  # cross-replica synchronization automatically!
return autoencoder_mnist.make_train_op(
minibatch=minibatch,
batch_size=minibatch[0].get_shape().as_list()[0],
batch_loss=batch_loss,
layer_collection=layer_collection,
loss_fn=loss_fn,
prev_train_batch=None,
placement_strategy='replica_round_robin',
)
def compute_squared_error(logits, targets):
"""Compute mean squared error."""
return tf.reduce_sum(
tf.reduce_mean(tf.square(targets - tf.nn.sigmoid(logits)), axis=0))
def compute_loss(logits, labels):
"""Compute loss value."""
graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
total_regularization_loss = tf.add_n(graph_regularizers)
loss_matrix = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
labels=labels)
loss = tf.reduce_sum(tf.reduce_mean(loss_matrix, axis=0))
regularized_loss = loss + total_regularization_loss
return regularized_loss
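# Illustrative sketch, assuming NumPy is available. `_np_sigmoid_xent_loss` is
# a hypothetical helper, not part of K-FAC or TensorFlow. It restates the
# unregularized part of `compute_loss` above: per-pixel sigmoid cross-entropy,
# averaged over the batch (axis 0) and summed over the 784 outputs. The
# elementwise term uses the numerically stable form documented for
# tf.nn.sigmoid_cross_entropy_with_logits:
#   max(x, 0) - x * z + log(1 + exp(-|x|))
import numpy as np


def _np_sigmoid_xent_loss(logits, labels):
  """Returns the sum over features of the batch-mean sigmoid cross-entropy."""
  x = np.asarray(logits, dtype=np.float64)
  z = np.asarray(labels, dtype=np.float64)
  # Stable elementwise loss: exp() is only ever applied to non-positive values.
  loss_matrix = np.maximum(x, 0.0) - x * z + np.log1p(np.exp(-np.abs(x)))
  # Mean over the batch axis, then sum over the feature dimensions.
  return np.sum(np.mean(loss_matrix, axis=0))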
def mnist_input_fn(params):
dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)
  # Shuffling before repeating is correct unless you want repeated examples
  # to be able to appear in the same batch.
dataset = (
dataset.shuffle(num_examples).repeat().batch(
params['batch_size'],
drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE))
return dataset
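# Plain-Python sketch of why `mnist_input_fn` shuffles before repeating.
# `_epoch_permutation_batches` is a hypothetical helper, not tf.data code.
# Assumption: as with shuffle(num_examples).repeat() (buffer as large as the
# dataset, reshuffled each epoch), every epoch is one full permutation, so a
# batch lying inside a single epoch cannot contain the same example twice.
# Shuffling the repeated stream instead can mix copies of an example from
# neighbouring epochs into one batch.
import random


def _epoch_permutation_batches(num_examples, batch_size, num_epochs, seed=0):
  """Returns batches drawn from per-epoch permutations (shuffle, then repeat)."""
  rng = random.Random(seed)
  stream = []
  for _ in range(num_epochs):
    epoch = list(range(num_examples))
    rng.shuffle(epoch)  # A fresh full permutation for each epoch.
    stream.extend(epoch)
  # Like drop_remainder=True: discard a trailing partial batch.
  num_batches = len(stream) // batch_size
  return [stream[i * batch_size:(i + 1) * batch_size]
          for i in range(num_batches)]

# If batch_size does not divide num_examples, a batch can straddle an epoch
# boundary and may then contain a repeated example.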
def print_tensors(**tensors):
"""Host call function to print Tensors from the TPU during training."""
print_op = tf.no_op()
for name in sorted(tensors):
with tf.control_dependencies([print_op]):
tensor = tensors[name]
if name in ['error', 'loss']:
tensor = tf.reduce_mean(tensor)
print_op = tf.Print(tensor, [tensor], message=name + '=')
with tf.control_dependencies([print_op]):
return tf.Print(0., [0.], message='------')
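# Sketch of how the host_call above is assumed to receive its inputs. Per the
# TPUEstimator documentation, host_call tensors from all cores are
# concatenated along dimension 0 (so they must have rank >= 1), which is why
# `_model_fn` wraps each scalar with tf.expand_dims(..., 0). `print_tensors`
# then averages only 'loss' and 'error'. `_simulate_host_call` is a
# hypothetical NumPy stand-in, not TPU code.
import numpy as np


def _simulate_host_call(per_core_tensors):
  """Merges a list of per-core dicts of shape-(1,) arrays like host_call."""
  merged = {
      name: np.concatenate([core[name] for core in per_core_tensors], axis=0)
      for name in per_core_tensors[0]
  }
  # Mirror print_tensors: average loss/error, keep other values per core.
  return {name: (np.mean(value) if name in ('error', 'loss') else value)
          for name, value in merged.items()}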
def _model_fn(features, labels, mode, params):
"""Estimator model_fn for an autoencoder with adaptive damping."""
del params
layer_collection = kfac.LayerCollection()
training_model_fn = autoencoder_mnist.AutoEncoder(784)
def loss_fn(minibatch, logits=None):
"""Compute the model loss given a batch of inputs.
Args:
minibatch: `Tuple[Tensor, Tensor]` for the current batch of input images
and labels.
logits: `Tensor` for the current batch of logits. If None then reuses the
AutoEncoder to compute them.
Returns:
`Tensor` for the batch loss.
"""
features, labels = minibatch
del labels
if logits is None:
# Note we do not need to do anything like
# `with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):`
# here because Sonnet takes care of variable reuse for us as long as we
# call the same `training_model_fn` module. Otherwise we would need to
# use variable reusing here.
logits = training_model_fn(features)
batch_loss = compute_loss(logits=logits, labels=features)
return batch_loss
logits = training_model_fn(features)
pre_update_batch_loss = loss_fn((features, labels), logits=logits)
pre_update_batch_error = compute_squared_error(logits, features)
if mode == tf.estimator.ModeKeys.TRAIN:
# Make sure never to confuse this with register_softmax_cross_entropy_loss!
layer_collection.register_sigmoid_cross_entropy_loss(logits,
seed=FLAGS.seed + 1)
layer_collection.auto_register_layers()
global_step = tf.train.get_or_create_global_step()
train_op, kfac_optimizer = make_train_op(
(features, labels),
pre_update_batch_loss,
layer_collection,
loss_fn)
tensors_to_print = {
'learning_rate': tf.expand_dims(kfac_optimizer.learning_rate, 0),
'momentum': tf.expand_dims(kfac_optimizer.momentum, 0),
'damping': tf.expand_dims(kfac_optimizer.damping, 0),
'global_step': tf.expand_dims(global_step, 0),
'loss': tf.expand_dims(pre_update_batch_loss, 0),
'error': tf.expand_dims(pre_update_batch_error, 0),
}
if FLAGS.adapt_damping:
tensors_to_print['qmodel_change'] = tf.expand_dims(
kfac_optimizer.qmodel_change, 0)
tensors_to_print['rho'] = tf.expand_dims(kfac_optimizer.rho, 0)
return contrib_tpu.TPUEstimatorSpec(
mode=mode,
loss=pre_update_batch_loss,
train_op=train_op,
host_call=(print_tensors, tensors_to_print),
eval_metrics=None)
else: # mode == tf.estimator.ModeKeys.{EVAL, PREDICT}:
return contrib_tpu.TPUEstimatorSpec(
mode=mode,
loss=pre_update_batch_loss,
eval_metrics=None)
def make_tpu_run_config(master, seed, model_dir, iterations_per_loop,
save_checkpoints_steps):
return contrib_tpu.RunConfig(
master=master,
evaluation_master=master,
model_dir=model_dir,
save_checkpoints_steps=save_checkpoints_steps,
cluster=None,
tf_random_seed=seed,
tpu_config=contrib_tpu.TPUConfig(iterations_per_loop=iterations_per_loop))
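# Sketch of what "cholesky decomposition + triangular solve" inversion (the
# posdef_inv_method='cholesky' setting in `main` below) means for a damped
# Kronecker factor. Assumptions: NumPy only; `_cholesky_posdef_inv` is a
# hypothetical helper, not the K-FAC implementation, and it uses plain
# forward/back substitution rather than an optimized triangular solver.
import numpy as np


def _cholesky_posdef_inv(factor, damping):
  """Returns inv(factor + damping * I) for a symmetric PSD `factor`."""
  n = factor.shape[0]
  damped = factor + damping * np.eye(n)
  lower = np.linalg.cholesky(damped)  # damped == lower @ lower.T
  eye = np.eye(n)
  # Forward substitution: solve lower @ y = I, one row of y at a time.
  y = np.zeros((n, n))
  for i in range(n):
    y[i] = (eye[i] - lower[i, :i] @ y[:i]) / lower[i, i]
  # Back substitution: solve lower.T @ x = y.
  upper = lower.T
  x = np.zeros((n, n))
  for i in reversed(range(n)):
    x[i] = (y[i] - upper[i, i + 1:] @ x[i + 1:]) / upper[i, i]
  return x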
def main(argv):
if FLAGS.use_control_flow_v2:
tf.enable_control_flow_v2()
del argv # Unused.
tf.set_random_seed(FLAGS.seed)
# Invert using cholesky decomposition + triangular solve. This is the only
# code path for matrix inversion supported on TPU right now.
kfac.utils.set_global_constants(posdef_inv_method='cholesky')
kfac.fisher_factors.set_global_constants(
eigenvalue_decomposition_threshold=10000)
config = make_tpu_run_config(
FLAGS.master, FLAGS.seed, FLAGS.model_dir, FLAGS.iterations_per_loop,
FLAGS.save_checkpoints_steps)
estimator = contrib_tpu.TPUEstimator(
use_tpu=True,
model_fn=_model_fn,
config=config,
train_batch_size=FLAGS.batch_size,
eval_batch_size=1024)
estimator.train(
input_fn=mnist_input_fn,
max_steps=FLAGS.train_steps,
hooks=[])
if __name__ == '__main__':
tf.disable_v2_behavior()
tf.app.run(main)
================================================
FILE: kfac/examples/autoencoder_mnist_tpu_strategy.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Deep AutoEncoder from Martens & Grosse (2015).
This script demonstrates training on TPUs with TPUStrategy using the KFAC
optimizer, updating the damping parameter according to the
Levenberg-Marquardt rule, and using the quadratic model method for adapting
the learning rate and momentum parameters.
See third_party/tensorflow_kfac/google/examples/ae_tpu_xm_launcher.py
for an example Borg launch script. If you can't access this launch script,
some important things to know about running K-FAC on TPUs (at least for this
example) are that you must use high-precision matrix multiplications.
iterations_per_loop is not relevant when using TPU Strategy, but you must set
it to 1 when using TPU Estimator.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
from absl import flags
import kfac
import tensorflow.compat.v1 as tf
from kfac.examples import autoencoder_mnist
from kfac.examples import mnist
# TODO(znado): figure out the bug with this and update_damping_immediately=True.
# TODO(znado): Add checkpointing code to the training loop.
flags.DEFINE_integer('save_checkpoints_steps', 500,
'Number of iterations between model checkpoints.')
flags.DEFINE_string('model_dir', '', 'Model dir.')
# iterations_per_loop is not used with TPU Strategy. We keep the flag so the
# Estimator launching script can be used.
flags.DEFINE_integer('iterations_per_loop', 1,
'Number of iterations in a TPU training loop.')
flags.DEFINE_string('master', None,
'GRPC URL of the master '
'(e.g. grpc://ip.address.of.tpu:8470).')
FLAGS = flags.FLAGS
def make_train_op(minibatch,
batch_loss,
layer_collection,
loss_fn):
"""Constructs optimizer and train op.
Args:
minibatch: Tuple[Tensor, Tensor] representing the current batch of input
images and labels.
    batch_loss: Tensor of shape (). Loss with respect to minibatch to be
      minimized.
layer_collection: LayerCollection object. Registry for model parameters.
Required when using a K-FAC optimizer.
loss_fn: A function that when called constructs the graph to compute the
model loss on the current minibatch. Returns a Tensor of the loss scalar.
Returns:
train_op: Op that can be used to update model parameters.
optimizer: The KFAC optimizer used to produce train_op.
Raises:
ValueError: If layer_collection is None when K-FAC is selected as an
optimization method.
"""
# Do not use CrossShardOptimizer with K-FAC. K-FAC now handles its own
  # cross-replica synchronization automatically!
return autoencoder_mnist.make_train_op(
minibatch=minibatch,
batch_size=minibatch[0].get_shape().as_list()[0],
batch_loss=batch_loss,
layer_collection=layer_collection,
loss_fn=loss_fn,
prev_train_batch=None,
placement_strategy='replica_round_robin',
)
def compute_squared_error(logits, targets):
"""Compute mean squared error."""
return tf.reduce_sum(
tf.reduce_mean(tf.square(targets - tf.nn.sigmoid(logits)), axis=0))
def compute_loss(logits, labels, model):
"""Compute loss value."""
loss_matrix = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
labels=labels)
regularization_loss = tf.reduce_sum(model.losses)
crossentropy_loss = tf.reduce_sum(tf.reduce_mean(loss_matrix, axis=0))
return crossentropy_loss + regularization_loss
def mnist_input_fn(batch_size):
dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)
  # Shuffling before repeating is correct unless you want repeated examples
  # to be able to appear in the same batch.
dataset = (dataset.shuffle(num_examples)
.repeat()
.batch(batch_size, drop_remainder=True)
.prefetch(tf.data.experimental.AUTOTUNE))
return dataset
def _train_step(batch):
  """One TPUStrategy training step for an autoencoder with adaptive damping."""
features, labels = batch
model = autoencoder_mnist.get_keras_autoencoder(tensor=features)
def loss_fn(minibatch, logits=None):
"""Compute the model loss given a batch of inputs.
Args:
minibatch: `Tuple[Tensor, Tensor]` for the current batch of input images
and labels.
logits: `Tensor` for the current batch of logits. If None then reuses the
AutoEncoder to compute them.
Returns:
`Tensor` for the batch loss.
"""
features, labels = minibatch
del labels
if logits is None:
logits = model(features)
batch_loss = compute_loss(logits=logits, labels=features, model=model)
return batch_loss
logits = model.output
pre_update_batch_loss = loss_fn((features, labels), logits)
pre_update_batch_error = compute_squared_error(logits, features)
# binary_crossentropy corresponds to sigmoid_crossentropy.
layer_collection = kfac.keras.utils.get_layer_collection(
model, 'binary_crossentropy', seed=FLAGS.seed + 1)
global_step = tf.train.get_or_create_global_step()
train_op, kfac_optimizer = make_train_op(
(features, labels),
pre_update_batch_loss,
layer_collection,
loss_fn)
tensors_to_print = {
'learning_rate': kfac_optimizer.learning_rate,
'momentum': kfac_optimizer.momentum,
'damping': kfac_optimizer.damping,
'global_step': global_step,
'loss': pre_update_batch_loss,
'error': pre_update_batch_error,
}
if FLAGS.adapt_damping:
tensors_to_print['qmodel_change'] = kfac_optimizer.qmodel_change
tensors_to_print['rho'] = kfac_optimizer.rho
with tf.control_dependencies([train_op]):
return {k: tf.identity(v) for k, v in tensors_to_print.items()}
def train():
"""Trains the Autoencoder using TPU Strategy."""
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=FLAGS.master)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
with tpu_strategy.scope():
data = mnist_input_fn(batch_size=FLAGS.batch_size)
train_iterator = tpu_strategy.make_dataset_iterator(data)
tensor_dict = tpu_strategy.experimental_run(_train_step, train_iterator)
for k, v in tensor_dict.items():
if k in ('loss', 'error'): # Losses are NOT scaled for num replicas.
tensor_dict[k] = tpu_strategy.reduce(tf.distribute.ReduceOp.MEAN, v)
else: # Other tensors (hyperparameters) are identical across replicas.
# experimental_local_results gives you a tuple of per-replica values.
tensor_dict[k] = tpu_strategy.experimental_local_results(v)
config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(cluster_resolver.master(), config=config) as session:
session.run(tf.global_variables_initializer())
session.run(train_iterator.initializer)
print('Starting training.')
for step in range(FLAGS.train_steps):
values_dict = session.run(tensor_dict)
print('Training Step: {}'.format(step))
for k, v in values_dict.items():
print('{}: {}'.format(k, v))
print('Done training.')
def main(argv):
del argv # Unused.
tf.set_random_seed(FLAGS.seed)
# Invert using cholesky decomposition + triangular solve. This is the only
# code path for matrix inversion supported on TPU right now.
kfac.utils.set_global_constants(posdef_inv_method='cholesky')
kfac.fisher_factors.set_global_constants(
eigenvalue_decomposition_threshold=10000)
train()
if __name__ == '__main__':
tf.disable_v2_behavior()
tf.app.run(main)
================================================
FILE: kfac/examples/classifier_mnist.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple MNIST classifier example.
This script demonstrates training using KFAC optimizer, updating the damping
parameter according to the Levenberg-Marquardt rule, and using the quadratic
model method for adapting the learning rate and momentum parameters.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
# Dependency imports
from absl import flags
import kfac
import sonnet as snt
import tensorflow.compat.v1 as tf
from kfac.examples import mnist
from kfac.python.ops.kfac_utils import data_reader
from kfac.python.ops.kfac_utils import data_reader_alt
# Model parameters
_NONLINEARITY = tf.nn.relu # can also be tf.nn.tanh
_POOL = 'MAX' # can also be 'AVG'
flags.DEFINE_integer('train_steps', 10000, 'Number of training steps.')
flags.DEFINE_integer('inverse_update_period', 5,
'# of steps between computing inverse of Fisher factor '
'matrices.')
flags.DEFINE_integer('cov_update_period', 1,
                     '# of steps between computing covariance matrices.')
flags.DEFINE_integer('damping_adaptation_interval', 5,
'# of steps between updating the damping parameter.')
flags.DEFINE_integer('num_burnin_steps', 5, 'Number of steps at the '
                     'start of training where the optimizer will only perform '
                     'cov updates. Will not work on CrossShardOptimizer. See '
                     'PeriodicInvCovUpdateKfacOpt for details.')
flags.DEFINE_integer('seed', 12345, 'Random seed')
flags.DEFINE_float('learning_rate', 3e-3,
'Learning rate to use when lrmu_adaptation="off".')
flags.DEFINE_float('momentum', 0.9,
'Momentum decay value to use when '
'lrmu_adaptation="off" or "only_lr".')
flags.DEFINE_float('damping', 1e-2, 'The fixed damping value to use. This is '
'ignored if adapt_damping is True.')
flags.DEFINE_float('l2_reg', 1e-5,
'L2 regularization applied to weight matrices.')
flags.DEFINE_boolean('update_damping_immediately', True, 'Adapt the damping '
'immediately after the parameter update (i.e. in the same '
'sess.run() call). Only safe if everything is a resource '
'variable.')
flags.DEFINE_boolean('use_batch_size_schedule', True,
'If True then we use the growing mini-batch schedule from '
'the original K-FAC paper.')
flags.DEFINE_integer('batch_size', 1024,
'The size of the mini-batches to use if not using the '
'schedule.')
flags.DEFINE_string('lrmu_adaptation', 'on',
'If set to "on" then we use the quadratic model '
'based learning-rate and momentum adaptation method from '
'the original paper. Note that this only works well in '
'practice when use_batch_size_schedule=True. Can also '
                    'be set to "off" or "only_lr", which turn '
                    'it off or use a version where the momentum parameter '
                    'is fixed, respectively.')
flags.DEFINE_boolean('use_alt_data_reader', True,
'If True we use the alternative data reader for MNIST '
'that is faster for small datasets.')
flags.DEFINE_string('device', '/gpu:0',
'The device to run the major ops on.')
flags.DEFINE_boolean('adapt_damping', True,
'If True we use the LM rule for damping adaptation as '
'described in the original K-FAC paper.')
# When using damping adaptation it is advisable to start with a high
# value. This value is probably far too high to use for most neural nets
# if you aren't using damping adaptation. (Although it always depends on
# the scale of the loss.)
flags.DEFINE_float('initial_damping', 0.1,
'The initial damping value to use when adapt_damping is '
'True.')
flags.DEFINE_string('optimizer', 'kfac',
'The optimizer to use. Can be kfac or adam. If adam is '
                    'used, the various kfac hyperparameters map roughly onto '
'their Adam equivalents.')
flags.DEFINE_float('polyak_decay', 0.995, 'Rate of decay for Polyak averaging.')
flags.DEFINE_integer('eval_every', 50,
'Interval to print total training loss.')
flags.DEFINE_boolean('use_sua_approx', False,
'If True we use the SUA approximation for conv layers.')
flags.DEFINE_string('dtype', 'float32',
                    'The DTYPE to use for all computations. Can be float32 '
'or float64.')
flags.DEFINE_boolean('use_custom_patches_op', False,
'If True we use the custom XLA implementation of the op '
'which computes the second moment of patch vectors.')
FLAGS = flags.FLAGS
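# Sketch of the quadratic-model ("qmodel") rule behind lrmu_adaptation='on',
# as described in Martens & Grosse (2015). The update is
# delta = alpha * proposal + mu * prev_update, and (alpha, mu) minimize the
# local quadratic model
#   q(delta) = grad . delta + 0.5 * delta^T C delta,
# where C is the damped curvature matrix. Setting dq/d(alpha, mu) = 0 gives a
# 2x2 linear system. Assumptions: NumPy only and C given explicitly;
# `_qmodel_lr_momentum` is a hypothetical helper (K-FAC itself works with
# curvature-vector products and may differ in details).
import numpy as np


def _qmodel_lr_momentum(grad, curvature, proposal, prev_update):
  """Returns (alpha, mu) minimizing the quadratic model along two directions."""
  c_prop = curvature @ proposal
  c_prev = curvature @ prev_update
  # Gram matrix of the two directions under the curvature metric.
  m = np.array([[proposal @ c_prop, proposal @ c_prev],
                [prev_update @ c_prop, prev_update @ c_prev]])
  b = np.array([grad @ proposal, grad @ prev_update])
  alpha, mu = -np.linalg.solve(m, b)
  return alpha, mu

# Sanity check of the design: if the proposal is already the exact minimizer
# -inv(C) @ grad, the rule returns alpha = 1 and mu = 0.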
class Model(snt.AbstractModule):
"""CNN model for MNIST data."""
def _build(self, inputs):
if FLAGS.l2_reg:
regularizers = {'w': lambda w: FLAGS.l2_reg*tf.nn.l2_loss(w),
                      'b': lambda b: FLAGS.l2_reg*tf.nn.l2_loss(b),}
else:
regularizers = None
reshape = snt.BatchReshape([28, 28, 1])
conv = snt.Conv2D(2, 5, padding=snt.SAME, regularizers=regularizers)
act = _NONLINEARITY(conv(reshape(inputs)))
pool = tf.nn.pool(act, window_shape=(2, 2), pooling_type=_POOL,
padding=snt.SAME, strides=(2, 2))
conv = snt.Conv2D(4, 5, padding=snt.SAME, regularizers=regularizers)
act = _NONLINEARITY(conv(pool))
pool = tf.nn.pool(act, window_shape=(2, 2), pooling_type=_POOL,
padding=snt.SAME, strides=(2, 2))
flatten = snt.BatchFlatten()(pool)
linear = snt.Linear(32, regularizers=regularizers)(flatten)
return snt.Linear(10, regularizers=regularizers)(linear)
def make_train_op(minibatch,
batch_size,
batch_loss,
layer_collection,
loss_fn,
prev_train_batch=None,
placement_strategy=None,
print_logs=False,
tf_replicator=None):
"""Constructs optimizer and train op.
Args:
minibatch: A list/tuple of Tensors (typically representing the current
mini-batch of input images and labels).
batch_size: Tensor of shape (). Size of the training mini-batch.
batch_loss: Tensor of shape (). Mini-batch loss tensor.
layer_collection: LayerCollection object. Registry for model parameters.
Required when using a K-FAC optimizer.
loss_fn: Function which takes as input a mini-batch and returns the loss.
prev_train_batch: `Tensor` of the previous training batch, can be accessed
from the data_reader.CachedReader cached_batch property. (Default: None)
placement_strategy: `str`, the placement_strategy argument for
`KfacOptimizer`. (Default: None)
print_logs: `Bool`. If True we print logs using K-FAC's built-in
tf.print-based logs printer. (Default: False)
tf_replicator: A Replicator object or None. If not None, K-FAC will set
itself up to work inside of the provided TF-Replicator object.
(Default: None)
Returns:
train_op: Op that can be used to update model parameters.
optimizer: Optimizer used to produce train_op.
Raises:
ValueError: If layer_collection is None when K-FAC is selected as an
optimization method.
"""
global_step = tf.train.get_or_create_global_step()
if FLAGS.optimizer == 'kfac':
if FLAGS.lrmu_adaptation == 'on':
learning_rate = None
momentum = None
momentum_type = 'qmodel'
elif FLAGS.lrmu_adaptation == 'only_lr':
learning_rate = None
momentum = FLAGS.momentum
momentum_type = 'qmodel_fixedmu'
elif FLAGS.lrmu_adaptation == 'off':
learning_rate = FLAGS.learning_rate
momentum = FLAGS.momentum
# momentum_type = 'regular'
momentum_type = 'adam'
if FLAGS.adapt_damping:
damping = FLAGS.initial_damping
else:
damping = FLAGS.damping
optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
invert_every=FLAGS.inverse_update_period,
cov_update_every=FLAGS.cov_update_period,
learning_rate=learning_rate,
damping=damping,
cov_ema_decay=0.95,
momentum=momentum,
momentum_type=momentum_type,
layer_collection=layer_collection,
batch_size=batch_size,
num_burnin_steps=FLAGS.num_burnin_steps,
adapt_damping=FLAGS.adapt_damping,
# Note that many of the arguments below don't do anything when
# adapt_damping=False.
update_damping_immediately=FLAGS.update_damping_immediately,
is_chief=True,
prev_train_batch=prev_train_batch,
loss=batch_loss,
loss_fn=loss_fn,
damping_adaptation_decay=0.9,
damping_adaptation_interval=FLAGS.damping_adaptation_interval,
min_damping=1e-6,
l2_reg=FLAGS.l2_reg,
train_batch=minibatch,
placement_strategy=placement_strategy,
print_logs=print_logs,
tf_replicator=tf_replicator,
dtype=FLAGS.dtype,
)
elif FLAGS.optimizer == 'adam':
optimizer = tf.train.AdamOptimizer(
learning_rate=FLAGS.learning_rate,
beta1=FLAGS.momentum,
epsilon=FLAGS.damping,
beta2=0.99)
return optimizer.minimize(batch_loss, global_step=global_step), optimizer
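# Sketch of the Levenberg-Marquardt damping rule enabled by adapt_damping,
# following Martens & Grosse (2015). rho compares the actual loss change to
# the change predicted by the quadratic model (the optimizer's `rho` and
# `qmodel_change` quantities). Assumptions: the decay factor is
# omega = damping_adaptation_decay ** damping_adaptation_interval and damping
# is floored at min_damping, matching the arguments passed above; the exact
# K-FAC implementation may differ. `_lm_damping_update` is hypothetical.


def _lm_damping_update(damping, loss_before, loss_after, qmodel_change,
                       decay=0.9, interval=5, min_damping=1e-6):
  """Returns the new damping value after one adaptation step."""
  # qmodel_change is negative when the model predicts a decrease in loss.
  rho = (loss_after - loss_before) / qmodel_change
  omega = decay ** interval
  if rho > 0.75:
    damping *= omega  # Model predicted well: trust it more (less damping).
  elif rho < 0.25:
    damping /= omega  # Model predicted poorly: trust it less (more damping).
  return max(damping, min_damping)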
def compute_loss(logits=None,
labels=None,
return_error=False,
use_regularizer=True):
"""Compute loss value."""
loss_matrix = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
labels=labels)
total_loss = tf.reduce_mean(loss_matrix, axis=0)
if use_regularizer:
graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
total_regularization_loss = tf.add_n(graph_regularizers)
total_loss += tf.cast(total_regularization_loss, dtype=total_loss.dtype)
if return_error:
error = 1.0 - tf.reduce_mean(tf.cast(
tf.equal(labels, tf.argmax(logits, axis=1, output_type=tf.int32)),
tf.float32))
return total_loss, error
return total_loss
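# NumPy sketch of the unregularized part of `compute_loss` above: the mean
# sparse softmax cross-entropy, computed with a log-sum-exp shift for
# numerical stability, plus the classification error 1 - accuracy via an
# argmax over classes. `_np_softmax_xent_and_error` is a hypothetical helper,
# not TensorFlow code.
import numpy as np


def _np_softmax_xent_and_error(logits, labels):
  """Returns (mean cross-entropy, error rate) for integer class labels."""
  logits = np.asarray(logits, dtype=np.float64)
  labels = np.asarray(labels)
  shifted = logits - logits.max(axis=1, keepdims=True)  # Avoid overflow.
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
  loss = -np.mean(log_probs[np.arange(len(labels)), labels])
  error = 1.0 - np.mean(np.argmax(logits, axis=1) == labels)
  return loss, error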
def load_mnist():
"""Creates MNIST dataset and wraps it inside cached data reader.
Returns:
cached_reader: `data_reader.CachedReader` instance which wraps MNIST
dataset.
num_examples: int. The number of training examples.
"""
  # Wrap the dataset in a cached reader, which provides variable-sized
  # training batches and caches the most recently read training batch.
if not FLAGS.use_alt_data_reader:
# Version 1 using data_reader.py (slow!)
dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)
if FLAGS.use_batch_size_schedule:
max_batch_size = num_examples
else:
max_batch_size = FLAGS.batch_size
    # Shuffling before repeating is correct unless you want repeated examples
    # to be able to appear in the same batch.
dataset = (dataset.shuffle(num_examples).repeat()
.batch(max_batch_size).prefetch(5))
dataset = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
# This version of CachedDataReader requires the dataset to be shuffled
return data_reader.CachedDataReader(dataset, max_batch_size), num_examples
else:
# Version 2 using data_reader_alt.py (faster)
images, labels, num_examples = mnist.load_mnist_as_tensors(
flatten_images=True, dtype=tf.dtypes.as_dtype(FLAGS.dtype))
dataset = (images, labels)
# This version of CachedDataReader requires the dataset to NOT be shuffled
return data_reader_alt.CachedDataReader(dataset, num_examples), num_examples
def _get_batch_size_schedule(num_examples):
"""Returns training batch size schedule."""
minibatch_maxsize_targetiter = 100 # We use a smaller target iter here than
# in the autoencoder example.
minibatch_maxsize = num_examples
minibatch_startsize = 1000
div = (float(minibatch_maxsize_targetiter-1)
/ math.log(float(minibatch_maxsize)/minibatch_startsize, 2))
return [
min(int(2.**(float(k)/div) * minibatch_startsize), minibatch_maxsize)
for k in range(minibatch_maxsize_targetiter)
]
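# Worked example in plain Python. `_doubling_period` and `_schedule` are
# hypothetical helpers that restate `_get_batch_size_schedule` above without
# FLAGS; they use math.log2 instead of math.log(..., 2) so that powers of two
# come out exact. The schedule grows geometrically,
#   size_k = start * 2**(k / div),
# where div is chosen so that the last iteration (k = target_iter - 1)
# reaches num_examples. In other words, the batch size doubles every `div`
# iterations until it is clamped at num_examples.
import math


def _doubling_period(num_examples, start=1000, target_iter=100):
  return float(target_iter - 1) / math.log2(float(num_examples) / start)


def _schedule(num_examples, start=1000, target_iter=100):
  div = _doubling_period(num_examples, start, target_iter)
  return [min(int(2.**(float(k) / div) * start), num_examples)
          for k in range(target_iter)]

# For example, with 64000 examples: div = 99 / 6 = 16.5, so the batch size is
# 1000 at k = 0, 4000 at k = 33, and 64000 at k = 99.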
def group_assign(dest, source):
return tf.group(*(d.assign(s) for d, s in zip(dest, source)))
def make_eval_ops(train_vars, ema):
# This does evaluation with and without Polyak averaging.
images, labels, _ = mnist.load_mnist_as_tensors(
flatten_images=True, dtype=tf.dtypes.as_dtype(FLAGS.dtype))
eval_model = Model()
eval_model(images) # We need this dummy call because the variables won't
# exist otherwise.
eval_vars = eval_model.variables
update_eval_model = group_assign(eval_vars, train_vars)
with tf.control_dependencies([update_eval_model]):
logits = eval_model(images)
eval_loss, eval_error = compute_loss(
logits=logits, labels=labels, return_error=True)
with tf.control_dependencies([eval_loss, eval_error]):
update_eval_model_avg = group_assign(
eval_vars, (ema.average(t) for t in train_vars))
with tf.control_dependencies([update_eval_model_avg]):
logits = eval_model(images)
eval_loss_avg, eval_error_avg = compute_loss(
logits=logits, labels=labels, return_error=True)
return eval_loss, eval_error, eval_loss_avg, eval_error_avg
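# Plain-Python sketch of the Polyak (exponential moving) averaging behind the
# "avg" evaluation numbers. tf.train.ExponentialMovingAverage documents the
# update shadow = decay * shadow + (1 - decay) * value, and for a variable the
# shadow starts from the variable's value. `_polyak_average` is a hypothetical
# helper that works on plain floats rather than TF variables, with
# decay=0.995 matching the default of the polyak_decay flag.


def _polyak_average(values, decay=0.995):
  """Returns the EMA of `values`, with the shadow initialized to values[0]."""
  shadow = values[0]
  for v in values[1:]:
    shadow = decay * shadow + (1.0 - decay) * v
  return shadow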
def construct_train_quants():
with tf.device(FLAGS.device):
# Load dataset.
cached_reader, num_examples = load_mnist()
batch_size_schedule = _get_batch_size_schedule(num_examples)
batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')
minibatch = cached_reader(batch_size)
features, _ = minibatch
training_model = Model()
layer_collection = kfac.LayerCollection()
if FLAGS.use_sua_approx:
layer_collection.set_default_conv2d_approximation('kron_sua')
ema = tf.train.ExponentialMovingAverage(FLAGS.polyak_decay,
zero_debias=True)
def loss_fn(minibatch, logits=None, return_error=False):
features, labels = minibatch
if logits is None:
logits = training_model(features)
return compute_loss(
logits=logits,
labels=labels,
return_error=return_error)
logits = training_model(features)
(batch_loss, batch_error) = loss_fn(
minibatch, logits=logits, return_error=True)
# Make sure never to confuse this with register_sigmoid_cross_entropy_loss!
layer_collection.register_softmax_cross_entropy_loss(logits,
seed=FLAGS.seed + 1)
layer_collection.auto_register_layers()
train_vars = training_model.variables
# Make training op:
train_op, opt = make_train_op(
minibatch,
batch_size,
batch_loss,
layer_collection,
loss_fn=loss_fn,
prev_train_batch=cached_reader.cached_batch)
with tf.control_dependencies([train_op]):
train_op = ema.apply(train_vars)
# We clear out the regularizers collection when creating our evaluation
# graph (which uses different variables). It is important that we do this
# only after the train op is constructed, since the minimize() will call
# into the loss function (which includes the regularizer):
tf.get_default_graph().clear_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
# These aren't run in the same sess.run call as train_op:
(eval_loss, eval_error,
eval_loss_avg, eval_error_avg) = make_eval_ops(train_vars, ema)
return (train_op, opt, batch_loss, batch_error, batch_size_schedule,
batch_size, eval_loss, eval_error, eval_loss_avg, eval_error_avg)
def main(_):
# If using update_damping_immediately, resource variables must be enabled.
if FLAGS.update_damping_immediately:
tf.enable_resource_variables()
if not FLAGS.use_sua_approx:
if FLAGS.use_custom_patches_op:
kfac.fisher_factors.set_global_constants(
use_patches_second_moment_op=True
)
else:
# Temporary measure to save memory with giant batches:
kfac.fisher_factors.set_global_constants(
sub_sample_inputs=True,
inputs_to_extract_patches_factor=0.2)
tf.set_random_seed(FLAGS.seed)
(train_op, opt, batch_loss, batch_error, batch_size_schedule, batch_size,
eval_loss, eval_error,
eval_loss_avg, eval_error_avg) = construct_train_quants()
global_step = tf.train.get_or_create_global_step()
if FLAGS.optimizer == 'kfac':
# We need to put the control dependency on train_op here so that we are
# guaranteed to get the up-to-date values of these various quantities.
# Otherwise there is a race condition and we might get the old values,
# nondeterministically. Another solution would be to get these values in
# a separate sess.run call, but this can sometimes cause problems with
# training frameworks that use hooks (see the comments below).
with tf.control_dependencies([train_op]):
learning_rate = opt.learning_rate
momentum = opt.momentum
damping = opt.damping
rho = opt.rho
qmodel_change = opt.qmodel_change
# Without setting allow_soft_placement=True there will be problems when
# the optimizer tries to place certain ops like "mod" on the GPU (which isn't
# supported).
config = tf.ConfigProto(allow_soft_placement=True)
# Train model.
# It's good practice to put everything into a single sess.run call. The
# reason is that certain "training frameworks" like to run hooks at each
# sess.run call, and there is an implicit expectation there will only
# be one sess.run call every "iteration" of the "optimizer". For example,
# a framework might try to print the loss at each sess.run call, causing
# the mini-batch to be advanced, thus completely breaking the "cached
# batch" mechanism that the damping adaptation method may rely on. (Plus
# there will also be the extra cost of having to reevaluate the loss
# twice.) That said, we don't fully follow this advice here because doing
# so is inconvenient.
with tf.train.MonitoredTrainingSession(save_checkpoint_secs=30,
config=config) as sess:
for _ in range(FLAGS.train_steps):
i = sess.run(global_step)
if FLAGS.use_batch_size_schedule:
batch_size_ = batch_size_schedule[min(i, len(batch_size_schedule) - 1)]
else:
batch_size_ = FLAGS.batch_size
if FLAGS.optimizer == 'kfac':
(_, batch_loss_, batch_error_, learning_rate_, momentum_, damping_,
rho_, qmodel_change_) = sess.run([train_op, batch_loss, batch_error,
learning_rate, momentum, damping,
rho, qmodel_change],
feed_dict={batch_size: batch_size_})
else:
_, batch_loss_, batch_error_ = sess.run(
[train_op, batch_loss, batch_error],
feed_dict={batch_size: batch_size_})
# Print training stats.
tf.logging.info(
'iteration: %d', i)
tf.logging.info(
'mini-batch size: %d | mini-batch loss = %f | mini-batch error = %f ',
batch_size_, batch_loss_, batch_error_)
if FLAGS.optimizer == 'kfac':
tf.logging.info(
'learning_rate = %f | momentum = %f',
learning_rate_, momentum_)
tf.logging.info(
'damping = %f | rho = %f | qmodel_change = %f',
damping_, rho_, qmodel_change_)
# "Eval" here means just compute stuff on the full training set.
if (i+1) % FLAGS.eval_every == 0:
eval_loss_, eval_error_, eval_loss_avg_, eval_error_avg_ = sess.run(
[eval_loss, eval_error, eval_loss_avg, eval_error_avg])
tf.logging.info('-----------------------------------------------------')
tf.logging.info('eval_loss = %f | eval_error = %f',
eval_loss_, eval_error_)
tf.logging.info('eval_loss_avg = %f | eval_error_avg = %f',
eval_loss_avg_, eval_error_avg_)
tf.logging.info('-----------------------------------------------------')
else:
tf.logging.info('----')
if __name__ == '__main__':
tf.disable_v2_behavior()
tf.app.run(main)
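The training loop in `main` picks each step's batch size with `batch_size_schedule[min(i, len(batch_size_schedule) - 1)]`, so once the schedule runs out the last scheduled size is reused for all remaining steps. A minimal plain-Python sketch of that lookup; the function name and the three-entry schedule are hypothetical, for illustration only:

```python
def batch_size_for_step(step, schedule):
    """Returns the scheduled batch size for `step`, clamped to the last entry."""
    # Same indexing as the training loop: steps past the end of the
    # schedule keep reusing its final batch size.
    return schedule[min(step, len(schedule) - 1)]


schedule = [32, 64, 128]  # Hypothetical schedule, not from the example.
print([batch_size_for_step(i, schedule) for i in range(5)])
# [32, 64, 128, 128, 128]
```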
================================================
FILE: kfac/examples/classifier_mnist_tpu_estimator.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple MNIST classifier example.
This script demonstrates training on TPUs with TPU Estimator using the KFAC
optimizer, updating the damping parameter according to the
Levenberg-Marquardt rule, and using the quadratic model method for adapting
the learning rate and momentum parameters.
See third_party/tensorflow_kfac/google/examples/classifier_tpu_xm_launcher.py
for an example Borg launch script. If you can't access this launch script,
the most important thing to know about running K-FAC on TPUs (at least for
this example) is that you must use higher-precision matrix multiplications.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
from absl import flags
import kfac
import tensorflow.compat.v1 as tf
from tensorflow.contrib import tpu as contrib_tpu
from kfac.examples import classifier_mnist
from kfac.examples import mnist
flags.DEFINE_integer('save_checkpoints_steps', 500,
'Number of iterations between model checkpoints.')
flags.DEFINE_integer('iterations_per_loop', 100,
'Number of iterations in a TPU training loop.')
flags.DEFINE_string('model_dir', '', 'Model dir.')
flags.DEFINE_string('master', None,
'GRPC URL of the master '
'(e.g. grpc://ip.address.of.tpu:8470).')
FLAGS = flags.FLAGS
def make_train_op(minibatch,
batch_loss,
layer_collection,
loss_fn):
"""Constructs optimizer and train op.
Args:
minibatch: Tuple[Tensor, Tensor] representing the current batch of input
images and labels.
batch_loss: Tensor of shape (), Loss with respect to minibatch to be
minimized.
layer_collection: LayerCollection object. Registry for model parameters.
Required when using a K-FAC optimizer.
loss_fn: A function that when called constructs the graph to compute the
model loss on the current minibatch. Returns a Tensor of the loss scalar.
Returns:
train_op: Op that can be used to update model parameters.
optimizer: The KFAC optimizer used to produce train_op.
Raises:
ValueError: If layer_collection is None when K-FAC is selected as an
optimization method.
"""
# Do not use CrossShardOptimizer with K-FAC. K-FAC now handles its own
# cross-replica synchronization automatically!
return classifier_mnist.make_train_op(
minibatch=minibatch,
batch_size=minibatch[0].get_shape().as_list()[0],
batch_loss=batch_loss,
layer_collection=layer_collection,
loss_fn=loss_fn,
prev_train_batch=None,
placement_strategy='replica_round_robin',
)
def mnist_input_fn(params):
dataset, num_examples = mnist.load_mnist_as_dataset(flatten_images=True)
# Shuffle before repeat is correct unless you want repeat cases in the
# same batch.
dataset = (
dataset.shuffle(num_examples).repeat().batch(
params['batch_size'],
drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE))
return dataset
def print_tensors(**tensors):
"""Host call function to print Tensors from the TPU during training."""
print_op = tf.no_op()
for name in sorted(tensors):
with tf.control_dependencies([print_op]):
tensor = tensors[name]
if name in ['error', 'loss']:
tensor = tf.reduce_mean(tensor)
print_op = tf.Print(tensor, [tensor], message=name + '=')
with tf.control_dependencies([print_op]):
return tf.Print(0., [0.], message='------')
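`_model_fn` wraps every printed scalar with `tf.expand_dims(..., 0)`, and `print_tensors` averages `loss` and `error`. Our understanding is that `host_call` tensors must have rank of at least 1 and are concatenated across TPU cores along the first dimension before the host function runs. A NumPy sketch of that shape bookkeeping, using two hypothetical per-core losses:

```python
import numpy as np

# Hypothetical scalar losses computed on two TPU cores.
per_core_losses = [np.float32(0.5), np.float32(1.5)]

# Each core contributes a rank-1 tensor of shape [1] (cf. tf.expand_dims(x, 0)).
per_core = [np.expand_dims(loss, 0) for loss in per_core_losses]

# Concatenation across cores yields shape [num_cores] on the host...
gathered = np.concatenate(per_core, axis=0)
print(gathered.shape)  # (2,)

# ...which print_tensors reduces to a single value, as tf.reduce_mean does.
print(float(np.mean(gathered)))  # 1.0
```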
def _model_fn(features, labels, mode, params):
"""Estimator model_fn for an autoencoder with adaptive damping."""
del params
training_model = classifier_mnist.Model()
layer_collection = kfac.LayerCollection()
def loss_fn(minibatch,
logits=None,
return_error=False):
features, labels = minibatch
if logits is None:
# Note we do not need to do anything like
# `with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):`
# here because Sonnet takes care of variable reuse for us as long as we
# call the same `training_model` module. Otherwise we would need to
# use variable reusing here.
logits = training_model(features)
return classifier_mnist.compute_loss(logits=logits,
labels=labels,
return_error=return_error)
logits = training_model(features)
pre_update_batch_loss, pre_update_batch_error = loss_fn(
(features, labels),
logits=logits,
return_error=True)
global_step = tf.train.get_or_create_global_step()
if mode == tf.estimator.ModeKeys.TRAIN:
layer_collection.register_softmax_cross_entropy_loss(logits,
seed=FLAGS.seed + 1)
layer_collection.auto_register_layers()
train_op, kfac_optimizer = make_train_op(
(features, labels),
pre_update_batch_loss,
layer_collection,
loss_fn)
tensors_to_print = {
'learning_rate': tf.expand_dims(kfac_optimizer.learning_rate, 0),
'momentum': tf.expand_dims(kfac_optimizer.momentum, 0),
'damping': tf.expand_dims(kfac_optimizer.damping, 0),
'global_step': tf.expand_dims(global_step, 0),
'loss': tf.expand_dims(pre_update_batch_loss, 0),
'error': tf.expand_dims(pre_update_batch_error, 0),
}
if FLAGS.adapt_damping:
tensors_to_print['qmodel_change'] = tf.expand_dims(
kfac_optimizer.qmodel_change, 0)
tensors_to_print['rho'] = tf.expand_dims(kfac_optimizer.rho, 0)
return contrib_tpu.TPUEstimatorSpec(
mode=mode,
loss=pre_update_batch_loss,
train_op=train_op,
host_call=(print_tensors, tensors_to_print),
eval_metrics=None)
else: # mode == tf.estimator.ModeKeys.{EVAL, PREDICT}:
return contrib_tpu.TPUEstimatorSpec(
mode=mode,
loss=pre_update_batch_loss,
eval_metrics=None)
def make_tpu_run_config(master, seed, model_dir, iterations_per_loop,
save_checkpoints_steps):
return contrib_tpu.RunConfig(
master=master,
evaluation_master=master,
model_dir=model_dir,
save_checkpoints_steps=save_checkpoints_steps,
cluster=None,
tf_random_seed=seed,
tpu_config=contrib_tpu.TPUConfig(iterations_per_loop=iterations_per_loop))
def main(argv):
del argv # Unused.
# If using update_damping_immediately, resource variables must be enabled.
# (Although they probably will be by default on TPUs.)
if FLAGS.update_damping_immediately:
tf.enable_resource_variables()
tf.set_random_seed(FLAGS.seed)
# Invert using cholesky decomposition + triangular solve. This is the only
# code path for matrix inversion supported on TPU right now.
kfac.utils.set_global_constants(posdef_inv_method='cholesky')
kfac.fisher_factors.set_global_constants(
eigenvalue_decomposition_threshold=10000)
if not FLAGS.use_sua_approx:
if FLAGS.use_custom_patches_op:
kfac.fisher_factors.set_global_constants(
use_patches_second_moment_op=True
)
else:
# Temporary measure to save memory with giant batches:
kfac.fisher_factors.set_global_constants(
sub_sample_inputs=True,
inputs_to_extract_patches_factor=0.1)
config = make_tpu_run_config(
FLAGS.master, FLAGS.seed, FLAGS.model_dir, FLAGS.iterations_per_loop,
FLAGS.save_checkpoints_steps)
estimator = contrib_tpu.TPUEstimator(
use_tpu=True,
model_fn=_model_fn,
config=config,
train_batch_size=FLAGS.batch_size,
eval_batch_size=1024)
estimator.train(
input_fn=mnist_input_fn,
max_steps=FLAGS.train_steps,
hooks=[])
if __name__ == '__main__':
tf.disable_v2_behavior()
tf.app.run(main)
================================================
FILE: kfac/examples/convnet.py
================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train a ConvNet on MNIST using K-FAC.
This library demonstrates how to use K-FAC to train a 5-layer ConvNet on MNIST.
Note that this example is basically untuned and is not meant to work as an
actual demonstration of the power of the method. It may not even converge. It
merely demonstrates how to set up K-FAC to run under the various standard
modes of operation in TensorFlow, like SyncReplicas, Estimator, etc.
For an example of the method tuned properly and working well, see for example
the autoencoder_mnist.py example, which replicates the exact experiment from
the original K-FAC paper.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import kfac
import numpy as np
import tensorflow.compat.v1 as tf
from kfac.examples import mnist
__all__ = [
"conv_layer",
"fc_layer",
"max_pool_layer",
"build_model",
"minimize_loss_single_machine",
"distributed_grads_only_and_ops_chief_worker",
"distributed_grads_and_ops_dedicated_workers",
"train_mnist_single_machine",
"train_mnist_distributed_sync_replicas",
"train_mnist_multitower"
]
# Inverse update ops will be run every _INVERT_EVERY iterations.
_INVERT_EVERY = 10
# Covariance matrices will be updated every _COV_UPDATE_EVERY iterations.
_COV_UPDATE_EVERY = 1
# Displays loss every _REPORT_EVERY iterations.
_REPORT_EVERY = 10
# Use manual registration
_USE_MANUAL_REG = False
def fc_layer(layer_id, inputs, output_size):
"""Builds a fully connected layer.
Args:
layer_id: int. Integer ID for this layer's variables.
inputs: Tensor of shape [num_examples, input_size]. Each row corresponds
to a single example.
output_size: int. Number of output dimensions after fully connected layer.
Returns:
preactivations: Tensor of shape [num_examples, output_size]. Values of the
layer immediately before the activation function.
activations: Tensor of shape [num_examples, output_size]. Values of the
layer immediately after the activation function.
params: Tuple of (weights, bias), parameters for this layer.
"""
# TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
layer = tf.layers.Dense(
output_size,
kernel_initializer=tf.random_normal_initializer(),
name="fc_%d" % layer_id)
preactivations = layer(inputs)
activations = tf.nn.tanh(preactivations)
# Return (kernel, bias) as a (hashable) tuple instead of the list layer.weights.
return preactivations, activations, (layer.kernel, layer.bias)
def conv_layer(layer_id, inputs, kernel_size, out_channels):
"""Builds a convolutional layer with ReLU non-linearity.
Args:
layer_id: int. Integer ID for this layer's variables.
inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
corresponds to a single example.
kernel_size: int. Width and height of the convolution kernel. The kernel is
assumed to be square.
out_channels: int. Number of output features per pixel.
Returns:
preactivations: Tensor of shape [num_examples, width, height, out_channels].
Values of the layer immediately before the activation function.
activations: Tensor of shape [num_examples, width, height, out_channels].
Values of the layer immediately after the activation function.
params: Tuple of (kernel, bias), parameters for this layer.
"""
# TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
layer = tf.layers.Conv2D(
out_channels,
kernel_size=[kernel_size, kernel_size],
kernel_initializer=tf.random_normal_initializer(stddev=0.01),
padding="SAME",
name="conv_%d" % layer_id)
preactivations = layer(inputs)
activations = tf.nn.relu(preactivations)
# Return (kernel, bias) as a (hashable) tuple instead of the list layer.weights.
return preactivations, activations, (layer.kernel, layer.bias)
def max_pool_layer(layer_id, inputs, kernel_size, stride):
"""Build a max-pooling layer.
Args:
layer_id: int. Integer ID for this layer's variables.
inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
corresponds to a single example.
kernel_size: int. Width and height to pool over per input channel. The
kernel is assumed to be square.
stride: int. Step size between pooling operations.
Returns:
Tensor of shape [num_examples, ceil(width/stride), ceil(height/stride),
in_channels]. Result of applying max pooling to 'inputs'.
"""
# TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
with tf.variable_scope("pool_%d" % layer_id):
return tf.nn.max_pool(
inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1],
padding="SAME",
name="pool")
def build_model(examples,
labels,
num_labels,
layer_collection,
register_layers_manually=False):
"""Builds a ConvNet classification model.
Args:
examples: Tensor of shape [num_examples, height, width, channels]. Represents
inputs of model.
labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
by softmax for each example.
num_labels: int. Number of distinct values 'labels' can take on.
layer_collection: LayerCollection instance. Layers will be registered here.
register_layers_manually: bool. If True then register the layers with
layer_collection manually. (Default: False)
Returns:
loss: 0-D Tensor representing loss to be minimized.
accuracy: 0-D Tensor representing model's accuracy.
"""
# Build a ConvNet. For each layer with parameters, we'll keep track of the
# preactivations, activations, weights, and bias.
tf.logging.info("Building model.")
pre0, act0, params0 = conv_layer(
layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
pre2, act2, params2 = conv_layer(
layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
logits, _, params4 = fc_layer(
layer_id=4, inputs=flat_act3, output_size=num_labels)
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits))
accuracy = tf.reduce_mean(
tf.cast(tf.equal(tf.cast(labels, dtype=tf.int32),
tf.argmax(logits, axis=1, output_type=tf.int32)),
dtype=tf.float32))
with tf.device("/cpu:0"):
tf.summary.scalar("loss", loss)
tf.summary.scalar("accuracy", accuracy)
layer_collection.register_softmax_cross_entropy_loss(
logits, name="logits")
if register_layers_manually:
layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
pre0)
layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1,
pre2)
layer_collection.register_fully_connected(params4, flat_act3, logits)
return loss, accuracy
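`build_model` flattens `act3` using its static shape. Assuming 28x28x1 MNIST inputs (`flatten_images=False`), the size of `flat_act3` can be worked out by hand. With `padding="SAME"`, TensorFlow produces `ceil(size / stride)` outputs per spatial dimension, so the stride-1 convolutions keep the spatial size and each stride-2 pool halves it (rounding up). A small sketch of that arithmetic; the helper name is hypothetical:

```python
import math


def same_padding_size(size, stride):
    """Spatial output size of a conv/pool op with padding="SAME"."""
    return math.ceil(size / stride)


height = width = 28  # MNIST images.
channels = 1
# (kind, stride, out_channels) for each layer in build_model; pools keep channels.
layers = [("conv", 1, 16), ("pool", 2, None), ("conv", 1, 16), ("pool", 2, None)]
for kind, stride, out_channels in layers:
    height = same_padding_size(height, stride)
    width = same_padding_size(width, stride)
    if out_channels is not None:
        channels = out_channels
    print(kind, (height, width, channels))

flat_size = height * width * channels
print(flat_size)  # 784
```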
def minimize_loss_single_machine(loss,
accuracy,
layer_collection,
device=None,
session_config=None):
"""Minimize loss with K-FAC on a single machine.
Creates `PeriodicInvCovUpdateKfacOpt` which handles inverse and covariance
computation op placement and execution. A single Session is responsible for
running all of K-FAC's ops. The covariance and inverse update ops are placed
on `device`. All model variables are on CPU.
Args:
loss: 0-D Tensor. Loss to be minimized.
accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
device: string or None. The covariance and inverse update ops are run on
this device. If empty or None, the default device will be used.
(Default: None)
session_config: None or tf.ConfigProto. Configuration for tf.Session().
Returns:
final value for 'accuracy'.
"""
device_list = [] if not device else [device]
# Train with K-FAC.
g_step = tf.train.get_or_create_global_step()
optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
invert_every=_INVERT_EVERY,
cov_update_every=_COV_UPDATE_EVERY,
learning_rate=0.0001,
cov_ema_decay=0.95,
damping=0.001,
layer_collection=layer_collection,
placement_strategy="round_robin",
cov_devices=device_list,
inv_devices=device_list,
trans_devices=device_list,
momentum=0.9)
with tf.device(device):
train_op = optimizer.minimize(loss, global_step=g_step)
tf.logging.info("Starting training.")
with tf.train.MonitoredTrainingSession(config=session_config) as sess:
while not sess.should_stop():
global_step_, loss_, accuracy_, _ = sess.run(
[g_step, loss, accuracy, train_op])
if global_step_ % _REPORT_EVERY == 0:
tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
global_step_, loss_, accuracy_)
return accuracy_
def minimize_loss_single_machine_manual(loss,
accuracy,
layer_collection,
device=None,
session_config=None):
"""Minimize loss with K-FAC on a single machine(Illustrative purpose only).
This function does inverse and covariance computation manually
for illustrative pupose. Check `minimize_loss_single_machine` for
automatic inverse and covariance op placement and execution.
A single Session is responsible for running all of K-FAC's ops. The covariance
and inverse update ops are placed on `device`. All model variables are on CPU.
Args:
loss: 0-D Tensor. Loss to be minimized.
accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
device: string or None. The covariance and inverse update ops are run on
this device. If empty or None, the default device will be used.
(Default: None)
session_config: None or tf.ConfigProto. Configuration for tf.Session().
Returns:
final value for 'accuracy'.
"""
device_list = [] if not device else [device]
# Train with K-FAC.
g_step = tf.train.get_or_create_global_step()
optimizer = kfac.KfacOptimizer(
learning_rate=0.0001,
cov_ema_decay=0.95,
damping=0.001,
layer_collection=layer_collection,
placement_strategy="round_robin",
cov_devices=device_list,
inv_devices=device_list,
trans_devices=device_list,
momentum=0.9)
(cov_update_thunks,
inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
def make_update_op(update_thunks):
update_ops = [thunk() for thunk in update_thunks]
return tf.group(*update_ops)
cov_update_op = make_update_op(cov_update_thunks)
with tf.control_dependencies([cov_update_op]):
inverse_op = tf.cond(
tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
lambda: make_update_op(inv_update_thunks), tf.no_op)
with tf.control_dependencies([inverse_op]):
with tf.device(device):
train_op = optimizer.minimize(loss, global_step=g_step)
tf.logging.info("Starting training.")
with tf.train.MonitoredTrainingSession(config=session_config) as sess:
while not sess.should_stop():
global_step_, loss_, accuracy_, _ = sess.run(
[g_step, loss, accuracy, train_op])
if global_step_ % _REPORT_EVERY == 0:
tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
global_step_, loss_, accuracy_)
return accuracy_
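In `minimize_loss_single_machine_manual`, the covariance update runs before every training step, while the inverse update is wrapped in a `tf.cond` that fires only when `global_step % _INVERT_EVERY == 0`. A plain-Python model of which ops run at each step; the function and op names are illustrative, and the sketch only mirrors the control flow, not the actual TensorFlow ops:

```python
def ops_run_at_step(global_step, invert_every=10):
    """Ops executed in one sess.run of the manual K-FAC train op."""
    ops = ["cov_update"]  # Always runs first (control dependency).
    if global_step % invert_every == 0:
        ops.append("inv_update")  # The tf.cond branch is taken.
    ops.append("train")  # Depends on the inverse op (or tf.no_op).
    return ops


inverse_steps = [s for s in range(25) if "inv_update" in ops_run_at_step(s)]
print(inverse_steps)  # [0, 10, 20]
```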
def _is_gradient_task(task_id, num_tasks):
"""Returns True if this task should update the weights."""
if num_tasks < 3:
return True
return 0 <= task_id < 0.6 * num_tasks
def _is_cov_update_task(task_id, num_tasks):
"""Returns True if this task should update K-FAC's covariance matrices."""
if num_tasks < 3:
return False
return 0.6 * num_tasks <= task_id < num_tasks - 1
def _is_inv_update_task(task_id, num_tasks):
"""Returns True if this task should update K-FAC's preconditioner."""
if num_tasks < 3:
return False
return task_id == num_tasks - 1
def _num_gradient_tasks(num_tasks):
"""Number of tasks that will update weights."""
if num_tasks < 3:
return num_tasks
return int(np.ceil(0.6 * num_tasks))
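The four helpers above split worker tasks by index. A plain-Python restatement (the `task_role` name is hypothetical) makes the split concrete. With 10 tasks, tasks 0-5 compute gradients, 6-8 update covariances, and only task 9 inverts. With exactly 3 tasks there is no covariance worker at all, since task 2 is both past the 60% mark and the last task.

```python
import math


def task_role(task_id, num_tasks):
    """Mirrors _is_gradient_task / _is_cov_update_task / _is_inv_update_task."""
    if num_tasks < 3:
        return "gradient"
    if 0 <= task_id < 0.6 * num_tasks:
        return "gradient"
    if task_id < num_tasks - 1:
        return "covariance"
    return "inverse"


for n in (2, 3, 5, 10):
    roles = [task_role(i, n) for i in range(n)]
    print(n, roles)
    # The number of gradient tasks matches _num_gradient_tasks.
    expected = n if n < 3 else int(math.ceil(0.6 * n))
    assert roles.count("gradient") == expected
```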
def _make_distributed_train_op(
task_id,
num_worker_tasks,
num_ps_tasks,
layer_collection
):
"""Creates optimizer and distributed training op.
Constructs a K-FAC optimizer and wraps it in a `tf.train.SyncReplicasOptimizer`.
The caller is responsible for building the train op.
Args:
task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
num_worker_tasks: int. Number of workers in this distributed training setup.
num_ps_tasks: int. Number of parameter servers holding variables. If 0,
parameter servers are not used.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
Returns:
sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
optimizer.
optimizer: Instance of `KfacOptimizer`.
global_step: `tensor`, Global step.
"""
tf.logging.info("Task id : %d", task_id)
with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
global_step = tf.train.get_or_create_global_step()
optimizer = kfac.KfacOptimizer(
learning_rate=0.0001,
cov_ema_decay=0.95,
damping=0.001,
layer_collection=layer_collection,
momentum=0.9)
sync_optimizer = tf.train.SyncReplicasOptimizer(
opt=optimizer,
replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
total_num_replicas=num_worker_tasks)
return sync_optimizer, optimizer, global_step
def distributed_grads_only_and_ops_chief_worker(
task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
loss, accuracy, layer_collection, invert_every=10):
"""Minimize loss with a synchronous implementation of K-FAC.
All workers perform gradient computation. Chief worker applies gradient after
averaging the gradients obtained from all the workers. All workers block
execution until the update is applied. Chief worker runs covariance and
inverse update ops. Covariance and inverse matrices are placed on parameter
servers in a round robin manner. For further details on synchronous
distributed optimization check `tf.train.SyncReplicasOptimizer`.
Args:
task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
is_chief: `boolean`, `True` if the worker is chief worker.
num_worker_tasks: int. Number of workers in this distributed training setup.
num_ps_tasks: int. Number of parameter servers holding variables. If 0,
parameter servers are not used.
master: string. IP and port of TensorFlow runtime process. Set to empty
string to run locally.
checkpoint_dir: string or None. Path to store checkpoints under.
loss: 0-D Tensor. Loss to be minimized.
accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
run with each step.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
invert_every: int. Number of steps between inverse updates. (Default: 10)
Returns:
final value for 'accuracy'.
Raises:
ValueError: if task_id >= num_worker_tasks.
"""
sync_optimizer, optimizer, global_step = _make_distributed_train_op(
task_id, num_worker_tasks, num_ps_tasks, layer_collection)
(cov_update_thunks,
inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
tf.logging.info("Starting training.")
hooks = [sync_optimizer.make_session_run_hook(is_chief)]
def make_update_op(update_thunks):
update_ops = [thunk() for thunk in update_thunks]
return tf.group(*update_ops)
if is_chief:
cov_update_op = make_update_op(cov_update_thunks)
with tf.control_dependencies([cov_update_op]):
inverse_op = tf.cond(
tf.equal(tf.mod(global_step, invert_every), 0),
lambda: make_update_op(inv_update_thunks),
tf.no_op)
with tf.control_dependencies([inverse_op]):
train_op = sync_optimizer.minimize(loss, global_step=global_step)
else:
train_op = sync_optimizer.minimize(loss, global_step=global_step)
# Without setting allow_soft_placement=True there will be problems when
# the optimizer tries to place certain ops like "mod" on the GPU (which isn't
# supported).
config = tf.ConfigProto(allow_soft_placement=True)
with tf.train.MonitoredTrainingSession(
master=master,
is_chief=is_chief,
checkpoint_dir=checkpoint_dir,
hooks=hooks,
stop_grace_period_secs=0,
config=config) as sess:
while not sess.should_stop():
global_step_, loss_, accuracy_, _ = sess.run(
[global_step, loss, accuracy, train_op])
tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
loss_, accuracy_)
return accuracy_
def distributed_grads_and_ops_dedicated_workers(
task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
loss, accuracy, layer_collection):
"""Minimize loss with a synchronous implementation of K-FAC.
Different workers are responsible for different parts of K-FAC's ops. The
first 60% of tasks compute gradients, the remaining tasks except the last
accumulate covariance statistics, and the last task inverts the matrices used
to precondition gradients. With fewer than 3 tasks, every task computes
gradients. The chief worker computes and applies the update.
Args:
task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
is_chief: `boolean`, `True` if the worker is chief worker.
num_worker_tasks: int. Number of workers in this distributed training setup.
num_ps_tasks: int. Number of parameter servers holding variables. If 0,
parameter servers are not used.
master: string. IP and port of TensorFlow runtime process. Set to empty
string to run locally.
checkpoint_dir: string or None. Path to store checkpoints under.
loss: 0-D Tensor. Loss to be minimized.
accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
run with each step.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
Returns:
final value for 'accuracy'.
Raises:
ValueError: if task_id >= num_worker_tasks.
"""
sync_optimizer, optimizer, global_step = _make_distributed_train_op(
task_id, num_worker_tasks, num_ps_tasks, layer_collection)
_, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()
train_op = sync_optimizer.minimize(loss, global_step=global_step)
inv_update_queue = kfac.op_queue.OpQueue(inv_update_ops)
tf.logging.info("Starting training.")
is_chief = (task_id == 0)
hooks = [sync_optimizer.make_session_run_hook(is_chief)]
# Without setting allow_soft_placement=True there will be problems when
# the optimizer tries to place certain ops like "mod" on the GPU (which isn't
# supported).
config = tf.ConfigProto(allow_soft_placement=True)
with tf.train.MonitoredTrainingSession(
master=master,
is_chief=is_chief,
checkpoint_dir=checkpoint_dir,
hooks=hooks,
stop_grace_period_secs=0,
config=config) as sess:
while not sess.should_stop():
# Choose which op this task is responsible for running.
if _is_gradient_task(task_id, num_worker_tasks):
learning_op = train_op
elif _is_cov_update_task(task_id, num_worker_tasks):
learning_op = cov_update_op
elif _is_inv_update_task(task_id, num_worker_tasks):
learning_op = inv_update_queue.next_op(sess)
else:
raise ValueError("Which op should task %d do?" % task_id)
global_step_, loss_, accuracy_, _ = sess.run(
[global_step, loss, accuracy, learning_op])
tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
loss_, accuracy_)
return accuracy_
def train_mnist_single_machine(num_epochs,
use_fake_data=False,
device=None,
manual_op_exec=False):
"""Train a ConvNet on MNIST.
Args:
num_epochs: int. Number of passes to make over the training set.
use_fake_data: bool. If True, generate a synthetic dataset.
device: string or None. The covariance and inverse update ops are run on
this device. If empty or None, the default device will be used.
(Default: None)
manual_op_exec: bool. If `True` then `minimize_loss_single_machine_manual`
is called for training, which handles inverse and covariance computation
manually. This is shown only for illustrative purposes. Otherwise
`minimize_loss_single_machine` is called, which relies on
`PeriodicInvCovUpdateKfacOpt` for op placement and execution.
Returns:
accuracy of model on the final minibatch of training data.
"""
# Load a dataset.
tf.logging.info("Loading MNIST into memory.")
(examples, labels) = mnist.load_mnist_as_iterator(num_epochs,
128,
use_fake_data=use_fake_data,
flatten_images=False)
# Build a ConvNet.
layer_collection = kfac.LayerCollection()
loss, accuracy = build_model(
examples, labels, num_labels=10, layer_collection=layer_collection,
register_layers_manually=_USE_MANUAL_REG)
if not _USE_MANUAL_REG:
layer_collection.auto_register_layers()
# Without setting allow_soft_placement=True there will be problems when
# the optimizer tries to place certain ops like "mod" on the GPU (which isn't
# supported).
config = tf.ConfigProto(allow_soft_placement=True)
# Fit model.
if manual_op_exec:
return minimize_loss_single_machine_manual(
loss, accuracy, layer_collection, device=device, session_config=config)
else:
return minimize_loss_single_machine(
loss, accuracy, layer_collection, device=device, session_config=config)


def train_mnist_multitower(num_epochs, num_towers,
                           devices, use_fake_data=False, session_config=None):
  """Train a ConvNet on MNIST.

  Training data is split equally among the towers. Each tower computes loss on
  its own batch of data and the loss is aggregated on the CPU. The model
  variables are placed on the first tower. The covariance and inverse update
  ops and variables are placed on the specified devices in a round-robin
  manner.

  Args:
    num_epochs: int. Number of passes to make over the training set.
    num_towers: int. Number of towers. Currently overridden by `len(devices)`.
    devices: list of strings. List of devices to place the towers.
    use_fake_data: bool. If True, generate a synthetic dataset.
    session_config: None or tf.ConfigProto. Configuration for tf.Session().

  Returns:
    accuracy of model on the final minibatch of training data.
  """
  num_towers = 1 if not devices else len(devices)
  # Load a dataset.
  tower_batch_size = 128
  batch_size = tower_batch_size * num_towers
  tf.logging.info(
      ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
       "tower batch size.") % (batch_size, num_towers, tower_batch_size))
  (examples, labels) = mnist.load_mnist_as_iterator(num_epochs,
                                                    batch_size,
                                                    use_fake_data=use_fake_data,
                                                    flatten_images=False)

  # Split minibatch across towers.
  examples = tf.split(examples, num_towers)
  labels = tf.split(labels, num_towers)

  # Build a ConvNet. Each tower's layers will be added to the LayerCollection.
  layer_collection = kfac.LayerCollection()
  tower_results = []
  for tower_id in range(num_towers):
    with tf.device(devices[tower_id]):
      with tf.name_scope("tower%d" % tower_id):
        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
          tf.logging.info("Building tower %d." % tower_id)
          tower_results.append(
              build_model(
                  examples[tower_id],
                  labels[tower_id],
                  10,
                  layer_collection,
                  register_layers_manually=_USE_MANUAL_REG))
  losses, accuracies = zip(*tower_results)
  # When using multiple towers we only want to perform automatic
  # registration once, after the final tower is made.
  if not _USE_MANUAL_REG:
    layer_collection.auto_register_layers()

  # Average across towers.
  loss = tf.reduce_mean(losses)
  accuracy = tf.reduce_mean(accuracies)

  # Fit model.
  g_step = tf.train.get_or_create_global_step()
  optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
      invert_every=_INVERT_EVERY,
      cov_update_every=_COV_UPDATE_EVERY,
      learning_rate=0.0001,
      cov_ema_decay=0.95,
      damping=0.001,
      layer_collection=layer_collection,
      placement_strategy="round_robin",
      cov_devices=devices,
      inv_devices=devices,
      trans_devices=devices,
      momentum=0.9)

  with tf.device(devices[0]):
    train_op = optimizer.minimize(loss, global_step=g_step)

  # Without setting allow_soft_placement=True there will be problems when
  # the optimizer tries to place certain ops like "mod" on the GPU (which isn't
  # supported).
  if not session_config:
    session_config = tf.ConfigProto(allow_soft_placement=True)

  tf.logging.info("Starting training.")
  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
    while not sess.should_stop():
      global_step_, loss_, accuracy_, _ = sess.run(
          [g_step, loss, accuracy, train_op])

      if global_step_ % _REPORT_EVERY == 0:
        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                        global_step_, loss_, accuracy_)
  return accuracy_


def train_mnist_distributed_sync_replicas(task_id,
                                          is_chief,
                                          num_worker_tasks,
                                          num_ps_tasks,
                                          master,
                                          num_epochs,
                                          op_strategy,
                                          use_fake_data=False):
  """Train a ConvNet on MNIST using the sync replicas optimizer.

  Args:
    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
    is_chief: bool. `True` if the worker is the chief worker.
    num_worker_tasks: int. Number of workers in this distributed training setup.
    num_ps_tasks: int. Number of parameter servers holding variables.
    master: string. IP and port of TensorFlow runtime process.
    num_epochs: int. Number of passes to make over the training set.
    op_strategy: string. Strategy to run the covariance and inverse
      ops. If op_strategy == `chief_worker`, the covariance and inverse
      update ops are run on the chief worker; otherwise they are run on
      dedicated workers.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.

  Raises:
    ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"].
  """
  # Load a dataset.
  tf.logging.info("Loading MNIST into memory.")
  (examples, labels) = mnist.load_mnist_as_iterator(num_epochs,
                                                    128,
                                                    use_fake_data=use_fake_data,
                                                    flatten_images=False)

  # Build a ConvNet.
  layer_collection = kfac.LayerCollection()
  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
    loss, accuracy = build_model(
        examples, labels, num_labels=10, layer_collection=layer_collection,
        register_layers_manually=_USE_MANUAL_REG)
  if not _USE_MANUAL_REG:
    layer_collection.auto_register_layers()

  # Fit model.
  checkpoint_dir = None
  if op_strategy == "chief_worker":
    return distributed_grads_only_and_ops_chief_worker(
        task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
        checkpoint_dir, loss, accuracy, layer_collection)
  elif op_strategy == "dedicated_workers":
    return distributed_grads_and_ops_dedicated_workers(
        task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
        checkpoint_dir, loss, accuracy, layer_collection)
  else:
    raise ValueError("Only supported op strategies are: {}, {}".format(
        "chief_worker", "dedicated_workers"))


def train_mnist_estimator(num_epochs, use_fake_data=False):
  """Train a ConvNet on MNIST using tf.estimator.

  Args:
    num_epochs: int. Number of passes to make over the training set.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.
  """

  # Load a dataset.
  def input_fn():
    tf.logging.info("Loading MNIST into memory.")
    return mnist.load_mnist_as_iterator(num_epochs=num_epochs,
                                        batch_size=64,
                                        flatten_images=False,
                                        use_fake_data=use_fake_data)

  def model_fn(features, labels, mode, params):
    """Model function for a ConvNet trained with K-FAC.

    Args:
      features: Tensor of shape [batch_size, height, width, channels]. Input
        images.
      labels: Tensor of shape [batch_size]. Target labels for training.
      mode: tf.estimator.ModeKeys. Must be TRAIN.
      params: ignored.

    Returns:
      EstimatorSpec for training.

    Raises:
      ValueError: If 'mode' is anything other than TRAIN.
    """
    del params

    if mode != tf.estimator.ModeKeys.TRAIN:
      raise ValueError("Only training is supported with this API.")

    # Build a ConvNet.
    layer_collection = kfac.LayerCollection()
    loss, accuracy = build_model(
        features, labels, num_labels=10, layer_collection=layer_collection,
        register_layers_manually=_USE_MANUAL_REG)
    if not _USE_MANUAL_REG:
      layer_collection.auto_register_layers()

    # Train with K-FAC.
    global_step = tf.train.get_or_create_global_step()
    optimizer = kfac.KfacOptimizer(
        learning_rate=tf.train.exponential_decay(
            0.00002, global_step, 10000, 0.5, staircase=True),
        cov_ema_decay=0.95,
        damping=0.001,
        layer_collection=layer_collection,
        momentum=0.9)

    (cov_update_thunks,
     inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

    def make_update_op(update_thunks):
      update_ops = [thunk() for thunk in update_thunks]
      return tf.group(*update_ops)

    def make_batch_executed_op(update_thunks, batch_size=1):
      return tf.group(*kfac.utils.batch_execute(
          global_step, update_thunks, batch_size=batch_size))

    # Run cov_update_op every step. Run one inverse update op per step.
    cov_update_op = make_update_op(cov_update_thunks)
    with tf.control_dependencies([cov_update_op]):
      # But make sure to execute all the inverse ops on the first step.
      inverse_op = tf.cond(tf.equal(global_step, 0),
                           lambda: make_update_op(inv_update_thunks),
                           lambda: make_batch_executed_op(inv_update_thunks))
      with tf.control_dependencies([inverse_op]):
        train_op = optimizer.minimize(loss, global_step=global_step)

    # Print metrics every 5 sec.
    hooks = [
        tf.train.LoggingTensorHook(
            {
                "loss": loss,
                "accuracy": accuracy
            }, every_n_secs=5),
    ]
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)

  run_config = tf.estimator.RunConfig(
      model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100)

  # Train until input_fn() is empty with Estimator. This is a prerequisite for
  # TPU compatibility.
  estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
  estimator.train(input_fn=input_fn)
================================================
FILE: kfac/examples/keras/KFAC_vs_Adam_Experiment.md
================================================
# KFAC vs Adam Experiment
## Set Up
We compare KFAC and Adam on a ResNet-20 on the CIFAR10 dataset. We split CIFAR10
into training (40k), validation (10k), and test (10k) sets. We ran a random
hyperparameter search and chose the hyperparameters of the run that first
reached 89% validation accuracy, measured in training steps. We decay both the
learning rate and damping/epsilon exponentially. The final learning rate is
fixed at 1e-4, the final damping (KFAC) at 1e-6, and the final epsilon (Adam)
at 1e-8. The ranges of the tuned hyperparameters are below; the random search
samples every hyperparameter from a log-uniform distribution:
| Hyperparameter | Min | Max |
|---------------------------|------|-------|
| Init Learning Rate | 1e-2 | 10.0 |
| Init Damping (KFAC) | 1e-2 | 100.0 |
| Init Epsilon (Adam) | 1e-4 | 1.0 |
| 1 - Learning Rate Decay | 1e-4 | 0.1 |
| 1 - Damping/Epsilon Decay | 1e-4 | 0.1 |
| 1 - Momentum | 1e-2 | 0.3 |
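
The following is a minimal sketch of the search described above. It assumes "log uniform" means the logarithm of each value is drawn uniformly between the logarithms of the bounds in the table, and it samples one KFAC trial. The helper names and dictionary keys are hypothetical, not part of the `kfac` API. Quantities listed as `1 - x` are sampled on that scale and then converted back.

```python
import math
import random

# KFAC search ranges from the table above. Entries named "one_minus_*" are
# sampled on the "1 - x" scale and converted back in sample_trial().
KFAC_RANGES = {
    "init_learning_rate": (1e-2, 10.0),
    "init_damping": (1e-2, 100.0),
    "one_minus_lr_decay": (1e-4, 0.1),
    "one_minus_damping_decay": (1e-4, 0.1),
    "one_minus_momentum": (1e-2, 0.3),
}


def sample_log_uniform(low, high, rng):
  """Draws x such that log(x) is uniform on [log(low), log(high)]."""
  return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_trial(ranges, rng):
  """Samples one hyperparameter configuration for a random-search trial."""
  raw = {name: sample_log_uniform(lo, hi, rng)
         for name, (lo, hi) in ranges.items()}
  return {
      "init_learning_rate": raw["init_learning_rate"],
      "init_damping": raw["init_damping"],
      "lr_decay_rate": 1 - raw["one_minus_lr_decay"],
      "damping_decay_rate": 1 - raw["one_minus_damping_decay"],
      "momentum": 1 - raw["one_minus_momentum"],
  }


rng = random.Random(20190524)
trial = sample_trial(KFAC_RANGES, rng)
print(trial)
```

Sampling on the `1 - x` scale spreads trials evenly over orders of magnitude of the gap from 1. For example, momentum then lands between 0.7 and 0.99.
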
The initial tuning runs used seed 20190524 and the GPU training script on an
NVIDIA Tesla P100. After choosing the best hyperparameters, we reran them with
the following 10 random seeds: 351515, 382980, 934126, 891369, 64379, 402680,
672242, 421590, 498163, 448799.
## Results
The chosen hyperparameters were the following (to 6 decimal places):
| Hyperparameter | KFAC | Adam |
|---------------------------|----------|----------|
| Init Learning Rate | 0.227214 | 2.242663 |
| Init Damping (KFAC) | 0.288721 | |
| Init Epsilon (Adam) | | 0.183230 |
| 1 - Learning Rate Decay | 0.001090 | 0.000610 |
| 1 - Damping/Epsilon Decay | 0.000287 | 0.000213 |
| 1 - Momentum | 0.018580 | 0.029656 |
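
The sketch below relates these decay rates to the 7,500-step budget (`num_training_steps`) used in the notebooks. It assumes the decay is applied multiplicatively once per training step and clipped at the final value. This is an assumption for illustration only. It is not taken from the implementation of `kfac.keras.callbacks.ExponentialDecay`, and `decayed_value` and `steps_to_final` are hypothetical helpers.

```python
import math


def decayed_value(step, init_value, final_value, decay_rate):
  """Assumed schedule: per-step exponential decay, clipped at final_value."""
  return max(init_value * decay_rate ** step, final_value)


def steps_to_final(init_value, final_value, decay_rate):
  """Smallest step at which the assumed schedule reaches final_value."""
  return math.ceil(math.log(final_value / init_value) / math.log(decay_rate))


# Chosen KFAC hyperparameters from the table above.
kfac_lr_steps = steps_to_final(0.227214, 1e-4, 1 - 0.001090)
kfac_damping_steps = steps_to_final(0.288721, 1e-6, 1 - 0.000287)
print("learning rate reaches 1e-4 after", kfac_lr_steps, "steps")
print("damping reaches 1e-6 after", kfac_damping_steps, "steps")
```
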
## Training Curves
Below are the loss and accuracy curves on the training and test sets. The line
represents the mean of the 10 seed runs and the coloured region represents the
bootstrapped standard deviation. KFAC reaches 89% validation accuracy at step
4640 and Adam at step 6560 (measurements were taken every 40 steps).
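
The section does not say exactly how the bootstrap was computed. One common reading, assumed in the sketch below, is to resample the 10 seed runs with replacement and average each resample. The standard deviation of those averages is then reported at every measurement step. `mean_and_bootstrap_std` is a hypothetical helper, and the toy data is synthetic rather than the experiment's curves.

```python
import numpy as np


def mean_and_bootstrap_std(curves, num_resamples=1000, seed=0):
  """Mean curve across seeds and bootstrapped std of that mean.

  Args:
    curves: array of shape [num_seeds, num_points], one row per seed run.
    num_resamples: number of bootstrap resamples of the seed axis.
    seed: seed for the resampling random number generator.

  Returns:
    (mean, std), each of shape [num_points].
  """
  curves = np.asarray(curves, dtype=np.float64)
  rng = np.random.default_rng(seed)
  num_seeds = curves.shape[0]
  # Resample whole runs (rows) with replacement, then average each resample.
  idx = rng.integers(0, num_seeds, size=(num_resamples, num_seeds))
  resampled_means = curves[idx].mean(axis=1)  # [num_resamples, num_points]
  return curves.mean(axis=0), resampled_means.std(axis=0)


# Toy example: 10 synthetic "runs" with 5 accuracy measurements each.
noise = 0.01 * np.random.default_rng(1).standard_normal((10, 5))
toy_curves = np.linspace(0.5, 0.9, 5) + noise
mean, std = mean_and_bootstrap_std(toy_curves)
print(mean, std)
```
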


Across the other runs, KFAC decreases the training loss faster than Adam early
in training, and the two optimizers perform similarly later in training.
## Hyperparameter Analysis
We offer some analysis of the learning rate and damping for KFAC to aid in
choosing appropriate values for these hyperparameters. Plots with the rest of
the hyperparameters for both KFAC and Adam are in the plots folder.

In general, a higher learning rate requires a higher damping. A large learning
rate with low damping leads to divergence, whereas a low learning rate with high
damping leads to SGD-like behaviour, which is suboptimal. The plot above shows
little correlation because the decay schedules also play a large role, as shown
below:

A fast damping decay allows for faster training, but can easily lead to
divergence. The best runs are often close to diverging.

As expected, a high learning rate with a low decay can lead to divergence.

Just like with the learning rate and damping, the learning rate decay should
be proportional to the damping decay to prevent divergence while training
quickly.
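
The claim that heavy damping produces SGD-like behaviour can be checked on a toy problem. The sketch below is an illustration under stated assumptions, not K-FAC itself. It replaces K-FAC's Kronecker-factored Fisher approximation with a small dense positive semi-definite matrix F and solves for the damped direction (F + λI)^{-1} g directly with `numpy.linalg.solve`. As λ grows, the direction approaches g / λ, so a step with learning rate η behaves like SGD with step size η / λ. With small λ, each eigendirection of F is scaled by 1 / (eigenvalue + λ), so low-curvature directions receive large steps and large learning rates become prone to divergence. One reading of the decay recommendation, under this heavily damped approximation, is that decaying η and λ at similar rates keeps the effective step size η / λ roughly stable.

```python
import numpy as np


def damped_direction(curvature, grad, damping):
  """Solves (curvature + damping * I) d = grad for the update direction d."""
  n = curvature.shape[0]
  return np.linalg.solve(curvature + damping * np.eye(n), grad)


# Toy curvature matrix and gradient (synthetic, not from the experiment):
# F = A^T A / m is positive semi-definite by construction.
rng = np.random.default_rng(0)
a = rng.standard_normal((20, 5))
curvature = a.T @ a / 20.0
grad = rng.standard_normal(5)

for damping in (1e-3, 1e-1, 1e1, 1e3):
  direction = damped_direction(curvature, grad, damping)
  # Cosine similarity with the raw gradient approaches 1 as damping grows.
  cos = direction @ grad / (np.linalg.norm(direction) * np.linalg.norm(grad))
  print(f"damping={damping:g} |d|={np.linalg.norm(direction):.3g} "
        f"cos={cos:.4f}")
```
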
================================================
FILE: kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "_DDaAex5Q7u-"
},
"source": [
"##### Copyright 2019 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "both",
"colab": {},
"colab_type": "code",
"id": "W1dWWdNHQ9L0"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "_C170SDp6jBt"
},
"source": [
"# KFAC vs Adam on CIFAR10 on a GPU\n",
"\n",
"This notebook contains the code used to run the experiment comparing KFAC and Adam on CIFAR10 with a ResNet-20. The experiment was run on an NVIDIA Tesla P100, and the notebook can also be run on a public GPU Colab instance.\n",
"\n",
"[Open in Colab](https://colab.research.google.com/github/tensorflow/kfac/blob/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "rw0qz2RWkLeJ"
},
"outputs": [],
"source": [
"!pip install kfac"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "LfGyhnaOsgYu"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"import math\n",
"import kfac"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "DYWIY0C380ye"
},
"outputs": [],
"source": [
"TRAINING_SIZE = 40000\n",
"VALIDATION_SIZE = 10000\n",
"TEST_SIZE = 10000\n",
"SEED = 20190524\n",
"\n",
"num_training_steps = 7500\n",
"batch_size = 1000\n",
"layers = tf.keras.layers\n",
"\n",
"# We take the ceiling because we do not drop the remainder of the batch\n",
"compute_steps_per_epoch = lambda x: int(math.ceil(1. * x / batch_size))\n",
"steps_per_epoch = compute_steps_per_epoch(TRAINING_SIZE)\n",
"val_steps = compute_steps_per_epoch(VALIDATION_SIZE)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "GfeTgsbh5G4g"
},
"outputs": [],
"source": [
"optimizer_name = 'kfac'  # 'kfac' or 'adam'\n",
"\n",
"# Best Hyperparameters from the Random Search\n",
"if optimizer_name == 'kfac':\n",
"  init_learning_rate = 0.22721400059936694\n",
"  final_learning_rate = 1e-04\n",
"  init_damping = 0.28872127217018184\n",
"  final_damping = 1e-6\n",
"  momentum = 1 - 0.018580394981260295\n",
"  lr_decay_rate = 1 - 0.001090107322908028\n",
"  damping_decay_rate = 1 - 0.0002870880729016523\n",
"elif optimizer_name == 'adam':\n",
"  init_learning_rate = 2.24266320779\n",
"  final_learning_rate = 1e-4\n",
"  init_epsilon = 0.183230038808\n",
"  final_epsilon = 1e-8\n",
"  momentum = 1 - 0.0296561513388\n",
"  lr_decay_rate = 1 - 0.000610416031571\n",
"  epsilon_decay_rate = 1 - 0.000212682338199\n",
"else:\n",
"  raise ValueError('Ensure optimizer_name is kfac or adam')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "v3vSki-usp9k"
},
"source": [
"## Input Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "D2U3i5kgssy_"
},
"outputs": [],
"source": [
"def _parse_fn(x):\n",
"  image, label = x['image'], x['label']\n",
"  image = tf.cast(image, tf.float32)\n",
"  label = tf.cast(label, tf.int32)\n",
"  image = image / 127.5 - 1\n",
"  return image, label\n",
"\n",
"\n",
"def _augment_image(image, crop_amount, seed=None):\n",
"  # Random Brightness, Contrast, Jpeg Quality, Hue, and Saturation did not\n",
"  # seem to work well as augmentations for our training specifications\n",
"  input_shape = image.shape.as_list()\n",
"  cropped_size = [input_shape[0] - crop_amount,\n",
"                  input_shape[1] - crop_amount,\n",
"                  input_shape[2]]\n",
"  flipped = tf.image.random_flip_left_right(image, seed)\n",
"  cropped = tf.image.random_crop(flipped, cropped_size, seed)\n",
"  return tf.image.pad_to_bounding_box(image=cropped,\n",
"                                      offset_height=crop_amount // 2,\n",
"                                      offset_width=crop_amount // 2,\n",
"                                      target_height=input_shape[0],\n",
"                                      target_width=input_shape[1])\n",
"\n",
"\n",
"def _get_raw_data():\n",
"  # We split the training data into training and validation ourselves for\n",
"  # hyperparameter tuning.\n",
"  training_pct = int(100.0 * TRAINING_SIZE / (TRAINING_SIZE + VALIDATION_SIZE))\n",
"  train_split = tfds.Split.TRAIN.subsplit(tfds.percent[:training_pct])\n",
"  validation_split = tfds.Split.TRAIN.subsplit(tfds.percent[training_pct:])\n",
"\n",
"  train_data, info = tfds.load('cifar10:3.*.*', with_info=True, split=train_split)\n",
"  val_data = tfds.load('cifar10:3.*.*', split=validation_split)\n",
"  test_data = tfds.load('cifar10:3.*.*', split='test')\n",
"\n",
"  input_shape = info.features['image'].shape\n",
"  num_classes = info.features['label'].num_classes\n",
"  info = {'input_shape': input_shape, 'num_classes': num_classes}\n",
"  return info, train_data, val_data, test_data\n",
"\n",
"\n",
"def get_input_pipeline(batch_size=None,\n",
"                       use_augmentation=True,\n",
"                       seed=None,\n",
"                       crop_amount=6,\n",
"                       drop_remainder=False,\n",
"                       repeat_validation=True):\n",
"  \"\"\"Creates CIFAR10 Data Pipeline.\n",
"\n",
"  Args:\n",
"    batch_size (int): Batch size used for training.\n",
"    use_augmentation (bool): If true, applies random horizontal flips and\n",
"      crops, then pads the images back to their original size.\n",
"    seed (int): Random seed used for augmentation operations.\n",
"    crop_amount (int): Number of pixels to crop from the height and width of\n",
"      the image. So, the cropped image will be [height - crop_amount, width -\n",
"      crop_amount, channels] before it is padded to restore its original size.\n",
"    drop_remainder (bool): Whether to drop the remainder of the batch. Needs to\n",
"      be true to work on TPUs.\n",
"    repeat_validation (bool): Whether to repeat the validation set. Test set is\n",
"      never repeated.\n",
"\n",
"  Returns:\n",
"    A tuple (data, info). data is a dict mapping 'train', 'validation' and\n",
"    'test' to tf.data.Dataset objects; info is a dict with 'input_shape'\n",
"    (tuple) and 'num_classes' (int).\n",
"  \"\"\"\n",
"  info, train_data, val_data, test_data = _get_raw_data()\n",
"\n",
"  if not batch_size:\n",
"    batch_size = max(TRAINING_SIZE, VALIDATION_SIZE, TEST_SIZE)\n",
"\n",
"  train_data = train_data.map(_parse_fn).shuffle(8192, seed=seed).repeat()\n",
"  if use_augmentation:\n",
"    train_data = train_data.map(\n",
"        lambda x, y: (_augment_image(x, crop_amount, seed), y))\n",
"  train_data = train_data.batch(\n",
"      min(batch_size, TRAINING_SIZE), drop_remainder=drop_remainder)\n",
"  train_data = train_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
"\n",
"  val_data = val_data.map(_parse_fn)\n",
"  if repeat_validation:\n",
"    val_data = val_data.repeat()\n",
"  val_data = val_data.batch(\n",
"      min(batch_size, VALIDATION_SIZE), drop_remainder=drop_remainder)\n",
"  val_data = val_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
"\n",
"  # Don't repeat test data because it is only used once to evaluate at the end.\n",
"  test_data = test_data.map(_parse_fn)\n",
"  test_data = test_data.batch(\n",
"      min(batch_size, TEST_SIZE), drop_remainder=drop_remainder)\n",
"  test_data = test_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
"\n",
"  data = {'train': train_data, 'validation': val_data, 'test': test_data}\n",
"  return data, info"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "SLvlpsups2aR"
},
"source": [
"## Model - Resnet V2\n",
"\n",
"Based on https://keras.io/examples/cifar10_resnet/. The only difference is that tf.keras layer implementations are used."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Cch3Ld5Ds4i2"
},
"outputs": [],
"source": [
"def resnet_layer(inputs,\n",
"                 num_filters=16,\n",
"                 kernel_size=3,\n",
"                 strides=1,\n",
"                 activation='relu',\n",
"                 batch_normalization=True,\n",
"                 conv_first=True):\n",
"  \"\"\"2D Convolution-Batch Normalization-Activation stack builder.\n",
"\n",
"  Based on https://keras.io/examples/cifar10_resnet/.\n",
"\n",
"  Args:\n",
"    inputs (tensor): input tensor from input image or previous layer\n",
"    num_filters (int): Conv2D number of filters\n",
"    kernel_size (int): Conv2D square kernel dimensions\n",
"    strides (int): Conv2D square stride dimensions\n",
"    activation (string): activation name\n",
"    batch_normalization (bool): whether to include batch normalization\n",
"    conv_first (bool): conv-bn-activation (True) or bn-activation-conv (False)\n",
"\n",
"  Returns:\n",
"    x (tensor): tensor as input to the next layer\n",
"  \"\"\"\n",
"  conv = layers.Conv2D(num_filters,\n",
"                       kernel_size=kernel_size,\n",
"                       strides=strides,\n",
"                       padding='same',\n",
"                       kernel_initializer='he_normal',\n",
"                       kernel_regularizer=tf.keras.regularizers.l2(1e-4))\n",
"\n",
"  x = inputs\n",
"  if conv_first:\n",
"    x = conv(x)\n",
"    if batch_normalization:\n",
"      x = layers.BatchNormalization()(x)\n",
"    if activation is not None:\n",
"      x = layers.Activation(activation)(x)\n",
"  else:\n",
"    if batch_normalization:\n",
"      x = layers.BatchNormalization()(x)\n",
"    if activation is not None:\n",
"      x = layers.Activation(activation)(x)\n",
"    x = conv(x)\n",
"  return x\n",
"\n",
"\n",
"def resnet_v2(input_shape, depth, num_classes=10):\n",
"  \"\"\"ResNet Version 2 Model builder [b].\n",
"\n",
"  Based on https://keras.io/examples/cifar10_resnet/.\n",
"\n",
"  Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D, also known as\n",
"  bottleneck layers.\n",
"  First shortcut connection per layer is 1 x 1 Conv2D.\n",
"  Second and onwards shortcut connection is identity.\n",
"  At the beginning of each stage, the feature map size is halved (downsampled)\n",
"  by a convolutional layer with strides=2, while the number of filter maps is\n",
"  doubled. Within each stage, the layers have the same number of filters and\n",
"  the same filter map sizes.\n",
"  Features maps sizes:\n",
"  conv1  : 32x32,  16\n",
"  stage 0: 32x32,  64\n",
"  stage 1: 16x16, 128\n",
"  stage 2:  8x8,  256\n",
"\n",
"  Args:\n",
"    input_shape (tuple/list): shape of input image tensor\n",
"    depth (int): number of core convolutional layers\n",
"    num_classes (int): number of classes (CIFAR10 has 10)\n",
"\n",
"  Returns:\n",
"    model (Model): Keras model instance\n",
"  \"\"\"\n",
"  if (depth - 2) % 9 != 0:\n",
"    raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])')\n",
"  # Start model definition.\n",
"  num_filters_in = 16\n",
"  num_res_blocks = int((depth - 2) / 9)\n",
"\n",
"  inputs = tf.keras.Input(shape=input_shape)\n",
"  # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths\n",
"  x = resnet_layer(inputs=inputs, num_filters=num_filters_in, conv_first=True)\n",
"\n",
"  # Instantiate the stack of residual units\n",
"  for stage in range(3):\n",
"    for res_block in range(num_res_blocks):\n",
"      activation = 'relu'\n",
"      batch_normalization = True\n",
"      strides = 1\n",
"      if stage == 0:\n",
"        num_filters_out = num_filters_in * 4\n",
"        if res_block == 0:  # first layer and first stage\n",
"          activation = None\n",
"          batch_normalization = False\n",
"      else:\n",
"        num_filters_out = num_filters_in * 2\n",
"        if res_block == 0:  # first layer but not first stage\n",
"          strides = 2  # downsample\n",
"\n",
"      # bottleneck residual unit\n",
"      y = resnet_layer(inputs=x,\n",
"                       num_filters=num_filters_in,\n",
"                       kernel_size=1,\n",
"                       strides=strides,\n",
"                       activation=activation,\n",
"                       batch_normalization=batch_normalization,\n",
"                       conv_first=False)\n",
"      y = resnet_layer(inputs=y, num_filters=num_filters_in, conv_first=False)\n",
"      y = resnet_layer(inputs=y,\n",
"                       num_filters=num_filters_out,\n",
"                       kernel_size=1,\n",
"                       conv_first=False)\n",
"      if res_block == 0:\n",
"        # linear projection residual shortcut connection to match\n",
"        # changed dims\n",
"        x = resnet_layer(inputs=x,\n",
"                         num_filters=num_filters_out,\n",
"                         kernel_size=1,\n",
"                         strides=strides,\n",
"                         activation=None,\n",
"                         batch_normalization=False)\n",
"      x = layers.Add()([x, y])\n",
"\n",
"    num_filters_in = num_filters_out\n",
"\n",
"  # Add classifier on top.\n",
"  # v2 has BN-ReLU before Pooling\n",
"  x = layers.BatchNormalization()(x)\n",
"  x = layers.Activation('relu')(x)\n",
"  x = layers.AveragePooling2D(pool_size=8)(x)\n",
"  y = layers.Flatten()(x)\n",
"  outputs = layers.Dense(num_classes,\n",
"                         activation='softmax',\n",
"                         kernel_initializer='he_normal')(y)\n",
"\n",
"  # Instantiate model.\n",
"  model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
"  return model"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dAUaN-i9tHMY"
},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Hf5WFHYP8tT9"
},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"tf.set_random_seed(SEED)\n",
"\n",
"data, info = get_input_pipeline(batch_size=batch_size,\n",
"                                seed=SEED,\n",
"                                repeat_validation=True,\n",
"                                use_augmentation=True)\n",
"\n",
"model = resnet_v2(input_shape=info['input_shape'],\n",
"                  depth=20,\n",
"                  num_classes=info['num_classes'])\n",
"\n",
"loss = 'sparse_categorical_crossentropy'\n",
"\n",
"training_callbacks = [\n",
"    kfac.keras.callbacks.ExponentialDecay(hyperparameter='learning_rate',\n",
"                                          init_value=init_learning_rate,\n",
"                                          final_value=final_learning_rate,\n",
"                                          decay_rate=lr_decay_rate)\n",
"]\n",
"\n",
"if optimizer_name == 'kfac':\n",
"  opt = kfac.keras.optimizers.Kfac(learning_rate=init_learning_rate,\n",
"                                   damping=init_damping,\n",
"                                   model=model,\n",
"                                   loss=loss,\n",
"                                   momentum=momentum,\n",
"                                   seed=SEED)\n",
"  training_callbacks.append(kfac.keras.callbacks.ExponentialDecay(\n",
"      hyperparameter='damping',\n",
"      init_value=init_damping,\n",
"      final_value=final_damping,\n",
"      decay_rate=damping_decay_rate))\n",
"\n",
"elif optimizer_name == 'adam':\n",
"  opt = tf.keras.optimizers.Adam(learning_rate=init_learning_rate,\n",
"                                 beta_1=momentum,\n",
"                                 epsilon=init_epsilon)\n",
"  training_callbacks.append(kfac.keras.callbacks.ExponentialDecay(\n",
"      hyperparameter='epsilon',\n",
"      init_value=init_epsilon,\n",
"      final_value=final_epsilon,\n",
"      decay_rate=epsilon_decay_rate))\n",
"\n",
"else:\n",
"  raise ValueError('optimizer_name must be \"adam\" or \"kfac\"')\n",
"\n",
"model.compile(loss=loss, optimizer=opt, metrics=['acc'])"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "dD8b27hLy6lO"
},
"outputs": [],
"source": [
"history = model.fit(x=data['train'],\n",
"                    epochs=num_training_steps//steps_per_epoch,\n",
"                    steps_per_epoch=steps_per_epoch,\n",
"                    validation_data=data['validation'],\n",
"                    validation_steps=val_steps,\n",
"                    callbacks=training_callbacks)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"_DDaAex5Q7u-"
],
"last_runtime": {
"build_target": "",
"kind": "local"
},
"name": "KFAC vs Adam on CIFAR10.ipynb",
"provenance": [
{
"file_id": "1pqtoYduODZyJKt4-kwVkt_KtNQCnaNDp",
"timestamp": 1565229994386
}
],
"version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 2",
"name": "python2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "_DDaAex5Q7u-"
},
"source": [
"##### Copyright 2019 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "both",
"colab": {},
"colab_type": "code",
"id": "W1dWWdNHQ9L0"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "KDGkOqGA54FB"
},
"source": [
"# KFAC vs Adam on CIFAR10 on TPUs\n",
"\n",
"This notebook demonstrates how to write a custom training loop with TPU Strategy and KFAC. It can be run on a public TPU Colab instance. Using KFAC with TPU Strategy differs from normal TPU Strategy use in two ways. First, the model and optimizer must be created inside your train step. Second, as a consequence of the first requirement, KFAC does not work with model.fit. We also use a batch_size of 1024 instead of 1000 to better utilize the TPUs.\n",
"\n",
"[Open in Colab](https://colab.research.google.com/github/tensorflow/kfac/blob/master/kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "At1AvF75kmlr"
},
"outputs": [],
"source": [
"!pip install kfac"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "LfGyhnaOsgYu"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import math\n",
"import kfac\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "DYWIY0C380ye"
},
"outputs": [],
"source": [
"TRAINING_SIZE = 40000\n",
"VALIDATION_SIZE = 10000\n",
"TEST_SIZE = 10000\n",
"SEED = 20190524\n",
"\n",
"num_training_steps = 7500\n",
"# We use a batch size of 1024 instead of 1000 because each TPU core should\n",
"# (ideally) get a batch whose size is a multiple of 128 (here we have 8 cores).\n",
"batch_size = 1024\n",
"layers = tf.keras.layers\n",
"\n",
"# We take the floor because the remainder of the batch is dropped on TPUs.\n",
"compute_steps_per_epoch = lambda x: int(math.floor(1. * x / batch_size))\n",
"steps_per_epoch = compute_steps_per_epoch(TRAINING_SIZE)\n",
"val_steps = compute_steps_per_epoch(VALIDATION_SIZE)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "GfeTgsbh5G4g"
},
"outputs": [],
"source": [
"optimizer_name = 'kfac'  # 'kfac' or 'adam'\n",
"\n",
"# Best Hyperparameters from the Random Search\n",
"if optimizer_name == 'kfac':\n",
"  init_learning_rate = 0.22721400059936694\n",
"  final_learning_rate = 1e-04\n",
"  init_damping = 0.28872127217018184\n",
"  final_damping = 1e-6\n",
"  momentum = 1 - 0.018580394981260295\n",
"  lr_decay_rate = 1 - 0.001090107322908028\n",
"  damping_decay_rate = 1 - 0.0002870880729016523\n",
"elif optimizer_name == 'adam':\n",
"  init_learning_rate = 2.24266320779\n",
"  final_learning_rate = 1e-4\n",
"  init_epsilon = 0.183230038808\n",
"  final_epsilon = 1e-8\n",
"  momentum = 1 - 0.0296561513388\n",
"  lr_decay_rate = 1 - 0.000610416031571\n",
"  epsilon_decay_rate = 1 - 0.000212682338199\n",
"else:\n",
"  raise ValueError('Ensure optimizer_name is kfac or adam')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "v3vSki-usp9k"
},
"source": [
"## Input Pipeline\n",
"\n",
"The tensorflow_datasets CIFAR10 dataset (used in the GPU notebook) does not work with the public TPUs, because of the way the tf.data.Dataset is downloaded. So, this pipeline uses the tf.keras version, which downloads numpy arrays that we turn into a tf.data.Dataset.\n",
"\n",
"If this pipeline does not work, try using the pipeline in the GPU notebook instead."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "D2U3i5kgssy_"
},
"outputs": [],
"source": [
"def _parse_fn(image, label):\n",
" image = tf.cast(image, tf.float32)\n",
" label = tf.cast(tf.squeeze(label), tf.int32)\n",
" image = image / 127.5 - 1\n",
" return image, label\n",
"\n",
"\n",
"def _augment_image(image, crop_amount, seed=None):\n",
" # Random Brightness, Contrast, Jpeg Quality, Hue, and Saturation did not\n",
" # seem to work well as augmentations for our training specifications\n",
" input_shape = image.shape.as_list()\n",
" cropped_size = [input_shape[0] - crop_amount,\n",
" input_shape[1] - crop_amount,\n",
" input_shape[2]]\n",
" flipped = tf.image.random_flip_left_right(image, seed)\n",
" cropped = tf.image.random_crop(flipped, cropped_size, seed)\n",
" return tf.image.pad_to_bounding_box(image=cropped,\n",
" offset_height=crop_amount // 2,\n",
" offset_width=crop_amount // 2,\n",
" target_height=input_shape[0],\n",
" target_width=input_shape[1])\n",
"\n",
"\n",
"def _get_raw_data():\n",
" # We split the training data into training and validation ourselves for\n",
" # hyperparameter tuning.\n",
" train_and_val, test = tf.keras.datasets.cifar10.load_data()\n",
" train = (train_and_val[0][:TRAINING_SIZE], train_and_val[1][:TRAINING_SIZE])\n",
" val = (train_and_val[0][TRAINING_SIZE:], train_and_val[1][TRAINING_SIZE:])\n",
" info = {'input_shape': train_and_val[0].shape[1:], 'num_classes': 10}\n",
" return (info,\n",
" tf.data.Dataset.from_tensor_slices(train),\n",
" tf.data.Dataset.from_tensor_slices(val),\n",
" tf.data.Dataset.from_tensor_slices(test))\n",
"\n",
"\n",
"def get_input_pipeline(batch_size=None,\n",
" use_augmentation=True,\n",
" seed=None,\n",
" crop_amount=6,\n",
" drop_remainder=False,\n",
" repeat_validation=True):\n",
" \"\"\"Creates CIFAR10 Data Pipeline.\n",
"\n",
" Args:\n",
" batch_size (int): Batch size used for training.\n",
" use_augmentation (bool): If true, applies random horizontal flips and crops,\n",
" then pads the images back to their original size.\n",
" seed (int): Random seed used for augmentation operations.\n",
" crop_amount (int): Number of pixels to crop from the height and width of the\n",
" image. So, the cropped image will be [height - crop_amount, width -\n",
" crop_amount, channels] before it is padded to restore its original size.\n",
" drop_remainder (bool): Whether to drop the final partial batch. Must be True\n",
" on TPUs, which require static batch shapes.\n",
" repeat_validation (bool): Whether to repeat the validation set. Test set is\n",
" never repeated.\n",
"\n",
" Returns:\n",
" A tuple (data, info). data is a dict with 'train', 'validation' and 'test'\n",
" tf.data.Dataset objects; info is a dict with input_shape (tuple) and\n",
" num_classes (int).\n",
" \"\"\"\n",
" info, train_data, val_data, test_data = _get_raw_data()\n",
"\n",
" if not batch_size:\n",
" batch_size = max(TRAINING_SIZE, VALIDATION_SIZE, TEST_SIZE)\n",
"\n",
" train_data = train_data.map(_parse_fn).shuffle(8192, seed=seed).repeat()\n",
" if use_augmentation:\n",
" train_data = train_data.map(\n",
" lambda x, y: (_augment_image(x, crop_amount, seed), y))\n",
" train_data = train_data.batch(\n",
" min(batch_size, TRAINING_SIZE), drop_remainder=drop_remainder)\n",
" train_data = train_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
"\n",
" val_data = val_data.map(_parse_fn)\n",
" if repeat_validation:\n",
" val_data = val_data.repeat()\n",
" val_data = val_data.batch(\n",
" min(batch_size, VALIDATION_SIZE), drop_remainder=drop_remainder)\n",
" val_data = val_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
"\n",
" # Don't repeat test data because it is only used once to evaluate at the end.\n",
" test_data = test_data.map(_parse_fn)\n",
" test_data = test_data.batch(\n",
" min(batch_size, TEST_SIZE), drop_remainder=drop_remainder)\n",
" test_data = test_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
"\n",
" data = {'train': train_data, 'validation': val_data, 'test': test_data}\n",
" return data, info"
]
},
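{
"cell_type": "markdown",
"metadata": {},
"source": [
"`_augment_image` crops a random `(height - crop_amount) x (width - crop_amount)` window and then pads it back to the original size with zeros, offset by `crop_amount // 2` pixels, which amounts to a random translation. Because `_parse_fn` has already rescaled pixels to [-1, 1], the zero padding corresponds to mid-gray. The NumPy sketch in the next cell reproduces only the crop-and-pad shape logic (the random flip is omitted); `numpy_crop_and_pad` is a hypothetical helper, not part of the pipeline."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def numpy_crop_and_pad(image, crop_amount, rng):\n",
"  # NumPy sketch of the crop-then-pad step in _augment_image (no flip).\n",
"  height, width, _ = image.shape\n",
"  top = rng.randint(0, crop_amount + 1)   # random crop offset, 0..crop_amount\n",
"  left = rng.randint(0, crop_amount + 1)\n",
"  cropped = image[top:top + height - crop_amount,\n",
"                  left:left + width - crop_amount]\n",
"  # Place the crop at a fixed offset inside a zero canvas of the original size.\n",
"  padded = np.zeros_like(image)\n",
"  offset = crop_amount // 2\n",
"  padded[offset:offset + cropped.shape[0],\n",
"         offset:offset + cropped.shape[1]] = cropped\n",
"  return padded\n",
"\n",
"\n",
"rng = np.random.RandomState(SEED)\n",
"example = rng.uniform(-1, 1, size=(32, 32, 3)).astype(np.float32)\n",
"augmented = numpy_crop_and_pad(example, crop_amount=6, rng=rng)\n",
"print(augmented.shape)  # (32, 32, 3): the original CIFAR10 size is restored\n",
"print(np.abs(augmented[:3]).max())  # 0.0: a 3-pixel zero border at the top"
]
},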
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "SLvlpsups2aR"
},
"source": [
"## Model - Resnet V2\n",
"\n",
"Based on https://keras.io/examples/cifar10_resnet/. The only differences are that tf.keras layer implementations are used and that an input tensor is used instead of the input shape (since TPUs don't support placeholders). Note that the final Dense layer applies a softmax, so the model outputs class probabilities rather than logits."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Cch3Ld5Ds4i2"
},
"outputs": [],
"source": [
"def resnet_layer(inputs,\n",
" num_filters=16,\n",
" kernel_size=3,\n",
" strides=1,\n",
" activation='relu',\n",
" batch_normalization=True,\n",
" conv_first=True):\n",
" \"\"\"2D Convolution-Batch Normalization-Activation stack builder.\n",
"\n",
" Based on https://keras.io/examples/cifar10_resnet/.\n",
"\n",
" Args:\n",
" inputs (tensor): input tensor from input image or previous layer\n",
" num_filters (int): Conv2D number of filters\n",
" kernel_size (int): Conv2D square kernel dimensions\n",
" strides (int): Conv2D square stride dimensions\n",
" activation (string): activation name\n",
" batch_normalization (bool): whether to include batch normalization\n",
" conv_first (bool): conv-bn-activation (True) or bn-activation-conv (False)\n",
"\n",
" Returns:\n",
" x (tensor): tensor as input to the next layer\n",
" \"\"\"\n",
" conv = layers.Conv2D(num_filters,\n",
" kernel_size=kernel_size,\n",
" strides=strides,\n",
" padding='same',\n",
" kernel_initializer='he_normal',\n",
" kernel_regularizer=tf.keras.regularizers.l2(1e-4))\n",
"\n",
" x = inputs\n",
" if conv_first:\n",
" x = conv(x)\n",
" if batch_normalization:\n",
" x = layers.BatchNormalization()(x)\n",
" if activation is not None:\n",
" x = layers.Activation(activation)(x)\n",
" else:\n",
" if batch_normalization:\n",
" x = layers.BatchNormalization()(x)\n",
" if activation is not None:\n",
" x = layers.Activation(activation)(x)\n",
" x = conv(x)\n",
" return x\n",
"\n",
"\n",
"def resnet_v2(input_tensor, depth, num_classes=10):\n",
" \"\"\"ResNet Version 2 Model builder [b].\n",
"\n",
" Based on https://keras.io/examples/cifar10_resnet/.\n",
"\n",
" Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D, also known as a\n",
" bottleneck layer.\n",
" First shortcut connection per layer is 1 x 1 Conv2D.\n",
" Second and onwards shortcut connection is identity.\n",
" At the beginning of each stage, the feature map size is halved (downsampled)\n",
" by a convolutional layer with strides=2, while the number of filter maps is\n",
" doubled. Within each stage, the layers have the same number of filters and the\n",
" same feature map sizes.\n",
" Feature map sizes:\n",
" conv1 : 32x32, 16\n",
" stage 0: 32x32, 64\n",
" stage 1: 16x16, 128\n",
" stage 2: 8x8, 256\n",
"\n",
" Args:\n",
" input_tensor (tensor): input image tensor\n",
" depth (int): number of core convolutional layers\n",
" num_classes (int): number of classes (CIFAR10 has 10)\n",
"\n",
" Returns:\n",
" model (Model): Keras model instance\n",
" \"\"\"\n",
" if (depth - 2) % 9 != 0:\n",
" raise ValueError('depth should be 9n+2 (e.g. 56 or 110 in [b])')\n",
" # Start model definition.\n",
" num_filters_in = 16\n",
" num_res_blocks = int((depth - 2) / 9)\n",
"\n",
" inputs = tf.keras.Input(tensor=input_tensor)\n",
" # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths\n",
" x = resnet_layer(inputs=inputs, num_filters=num_filters_in, conv_first=True)\n",
"\n",
" # Instantiate the stack of residual units\n",
" for stage in range(3):\n",
" for res_block in range(num_res_blocks):\n",
" activation = 'relu'\n",
" batch_normalization = True\n",
" strides = 1\n",
" if stage == 0:\n",
" num_filters_out = num_filters_in * 4\n",
" if res_block == 0: # first layer and first stage\n",
" activation = None\n",
" batch_normalization = False\n",
" else:\n",
" num_filters_out = num_filters_in * 2\n",
" if res_block == 0: # first layer but not first stage\n",
" strides = 2 # downsample\n",
"\n",
" # bottleneck residual unit\n",
" y = resnet_layer(inputs=x,\n",
" num_filters=num_filters_in,\n",
" kernel_size=1,\n",
" strides=strides,\n",
" activation=activation,\n",
" batch_normalization=batch_normalization,\n",
" conv_first=False)\n",
" y = resnet_layer(inputs=y, num_filters=num_filters_in, conv_first=False)\n",
" y = resnet_layer(inputs=y,\n",
" num_filters=num_filters_out,\n",
" kernel_size=1,\n",
" conv_first=False)\n",
" if res_block == 0:\n",
" # linear projection residual shortcut connection to match\n",
" # changed dims\n",
" x = resnet_layer(inputs=x,\n",
" num_filters=num_filters_out,\n",
" kernel_size=1,\n",
" strides=strides,\n",
" activation=None,\n",
" batch_normalization=False)\n",
" x = layers.Add()([x, y])\n",
"\n",
" num_filters_in = num_filters_out\n",
"\n",
" # Add classifier on top.\n",
" # v2 has BN-ReLU before Pooling\n",
" x = layers.BatchNormalization()(x)\n",
" x = layers.Activation('relu')(x)\n",
" x = layers.AveragePooling2D(pool_size=8)(x)\n",
" y = layers.Flatten()(x)\n",
" outputs = layers.Dense(num_classes,\n",
" activation='softmax',\n",
" kernel_initializer='he_normal')(y)\n",
"\n",
" # Instantiate model.\n",
" model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
" return model"
]
},
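{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `depth` argument of `resnet_v2` appears to count three convolutions per bottleneck block over three stages (`9n`), plus the input convolution and the final Dense layer (`+2`); the 1 x 1 projection shortcuts are not counted. The sketch in the next cell (the helper name `resnet_v2_layout` is hypothetical) mirrors that bookkeeping and prints the number of blocks and output filters per stage, for example for `depth=56`."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {},
"outputs": [],
"source": [
"def resnet_v2_layout(depth, num_filters_in=16):\n",
"  # Mirrors the stage/filter bookkeeping in resnet_v2 without building layers.\n",
"  if (depth - 2) % 9 != 0:\n",
"    raise ValueError('depth should be 9n+2')\n",
"  num_res_blocks = (depth - 2) // 9\n",
"  layout = []\n",
"  for stage in range(3):\n",
"    # Stage 0 widens 16 -> 64 filters; later stages double the width.\n",
"    num_filters_out = num_filters_in * (4 if stage == 0 else 2)\n",
"    layout.append((stage, num_res_blocks, num_filters_out))\n",
"    num_filters_in = num_filters_out\n",
"  return layout\n",
"\n",
"\n",
"for stage, blocks, filters in resnet_v2_layout(56):\n",
"  print('stage %d: %d bottleneck blocks, %d output filters' %\n",
"        (stage, blocks, filters))\n",
"# stage 0: 6 bottleneck blocks, 64 output filters\n",
"# stage 1: 6 bottleneck blocks, 128 output filters\n",
"# stage 2: 6 bottleneck blocks, 256 output filters"
]
},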
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hpSh8fWKiWO7"
},
"source": [
"## TPU Set Up"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "c_JSG9X7iaxu"
},
"outputs": [],
"source": [
"def get_tpu_address():\n",
" if 'TPU_NAME' in os.environ and 'COLAB_TPU_ADDR' in os.environ: # public colab\n",
" assert os.environ['COLAB_GPU'] == '0'\n",
" TPU_ADDRESS = os.environ['TPU_NAME']\n",
" from google.colab import auth\n",
" auth.authenticate_user()\n",
" print('Running on public colab https://colab.research.google.com')\n",
" elif 'TPU_NAME' in os.environ and 'COLAB_TPU_ADDR' not in os.environ: # Cloud TPU\n",
" TPU_ADDRESS = os.environ['TPU_NAME']\n",
" print('Running on Cloud TPU')\n",
" else:\n",
" raise ValueError('Unknown environment')\n",
" return TPU_ADDRESS"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5LaK9He6Yw4B"
},
"outputs": [],
"source": [
"cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(\n",
" tpu=get_tpu_address())\n",
"tf.tpu.experimental.initialize_tpu_system(cluster_resolver)\n",
"tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)"
]
},
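{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the strategy is created, it can be worth checking that the global batch size splits evenly across replicas. The sketch in the next cell reads `tpu_strategy.num_replicas_in_sync`; on the 8-core TPU assumed in the configuration cell, each replica should get 128 examples."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {},
"outputs": [],
"source": [
"num_replicas = tpu_strategy.num_replicas_in_sync\n",
"print('Number of replicas:', num_replicas)\n",
"# Each replica should get an equal share of the global batch, ideally a\n",
"# multiple of 128 examples (see the batch-size comment in the config cell).\n",
"assert batch_size % num_replicas == 0, 'batch_size must split evenly'\n",
"print('Per-replica batch size:', batch_size // num_replicas)"
]
},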
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dAUaN-i9tHMY"
},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ClkFCVbDp0HK"
},
"outputs": [],
"source": [
"def get_optimizer(model, loss, global_step):\n",
" decayed_learning_rate = tf.train.exponential_decay(init_learning_rate,\n",
" global_step=global_step,\n",
" decay_rate=lr_decay_rate,\n",
" decay_steps=1)\n",
" learning_rate = tf.maximum(decayed_learning_rate, final_learning_rate)\n",
"\n",
" if optimizer_name == 'kfac':\n",
" decayed_damping = tf.train.exponential_decay(init_damping,\n",
" global_step=global_step,\n",
" decay_rate=damping_decay_rate,\n",
" decay_steps=1)\n",
" damping = tf.maximum(decayed_damping, final_damping)\n",
" # We cannot use the Keras version because Keras optimizers do not support\n",
" # a global_step argument for minimize. Instead, we use the Keras automated\n",
" # layer collection functionality to get our layer collection.\n",
" lc = kfac.keras.utils.get_layer_collection(\n",
" model=model, loss=loss, seed=SEED)\n",
" optimizer = kfac.PeriodicInvCovUpdateKfacOpt(\n",
" learning_rate=learning_rate,\n",
" damping=damping,\n",
" momentum=momentum,\n",
" layer_collection=lc,\n",
" # Replica round robin places each inverse operation on a different\n",
" # replica (TPU core) so that each inverse is computed on one replica\n",
" # then the replicas are synced.\n",
" placement_strategy='replica_round_robin')\n",
"\n",
SYMBOL INDEX (1212 symbols across 44 files)
FILE: kfac/examples/autoencoder_mnist.py
function make_train_op (line 143) | def make_train_op(minibatch,
class AutoEncoder (line 243) | class AutoEncoder(snt.AbstractModule):
method __init__ (line 246) | def __init__(self,
method _build (line 277) | def _build(self, inputs):
class MLPManualReg (line 284) | class MLPManualReg(snt.AbstractModule):
method __init__ (line 286) | def __init__(self,
method _build (line 310) | def _build(self, inputs, layer_collection=None):
class AutoEncoderManualReg (line 330) | class AutoEncoderManualReg(snt.AbstractModule):
method __init__ (line 333) | def __init__(self,
method _build (line 365) | def _build(self, inputs, layer_collection=None):
function get_keras_autoencoder (line 372) | def get_keras_autoencoder(**input_kwargs):
function compute_squared_error (line 426) | def compute_squared_error(logits, targets):
function compute_loss (line 432) | def compute_loss(logits=None,
function load_mnist (line 455) | def load_mnist():
function _get_batch_size_schedule (line 493) | def _get_batch_size_schedule(minibatch_maxsize):
function construct_train_quants (line 506) | def construct_train_quants():
function main (line 567) | def main(_):
FILE: kfac/examples/autoencoder_mnist_tpu_estimator.py
function make_train_op (line 58) | def make_train_op(minibatch,
function compute_squared_error (line 96) | def compute_squared_error(logits, targets):
function compute_loss (line 102) | def compute_loss(logits, labels):
function mnist_input_fn (line 113) | def mnist_input_fn(params):
function print_tensors (line 125) | def print_tensors(**tensors):
function _model_fn (line 138) | def _model_fn(features, labels, mode, params):
function make_tpu_run_config (line 212) | def make_tpu_run_config(master, seed, model_dir, iterations_per_loop,
function main (line 224) | def main(argv):
FILE: kfac/examples/autoencoder_mnist_tpu_strategy.py
function make_train_op (line 62) | def make_train_op(minibatch,
function compute_squared_error (line 100) | def compute_squared_error(logits, targets):
function compute_loss (line 106) | def compute_loss(logits, labels, model):
function mnist_input_fn (line 115) | def mnist_input_fn(batch_size):
function _train_step (line 127) | def _train_step(batch):
function train (line 181) | def train():
function main (line 212) | def main(argv):
FILE: kfac/examples/classifier_mnist.py
class Model (line 141) | class Model(snt.AbstractModule):
method _build (line 144) | def _build(self, inputs):
function make_train_op (line 173) | def make_train_op(minibatch,
function compute_loss (line 272) | def compute_loss(logits=None,
function load_mnist (line 296) | def load_mnist():
function _get_batch_size_schedule (line 334) | def _get_batch_size_schedule(num_examples):
function group_assign (line 349) | def group_assign(dest, source):
function make_eval_ops (line 353) | def make_eval_ops(train_vars, ema):
function construct_train_quants (line 383) | def construct_train_quants():
function main (line 448) | def main(_):
FILE: kfac/examples/classifier_mnist_tpu_estimator.py
function make_train_op (line 58) | def make_train_op(minibatch,
function mnist_input_fn (line 96) | def mnist_input_fn(params):
function print_tensors (line 108) | def print_tensors(**tensors):
function _model_fn (line 121) | def _model_fn(features, labels, mode, params):
function make_tpu_run_config (line 193) | def make_tpu_run_config(master, seed, model_dir, iterations_per_loop,
function main (line 205) | def main(argv):
FILE: kfac/examples/convnet.py
function fc_layer (line 69) | def fc_layer(layer_id, inputs, output_size):
function conv_layer (line 97) | def conv_layer(layer_id, inputs, kernel_size, out_channels):
function max_pool_layer (line 129) | def max_pool_layer(layer_id, inputs, kernel_size, stride):
function build_model (line 152) | def build_model(examples,
function minimize_loss_single_machine (line 210) | def minimize_loss_single_machine(loss,
function minimize_loss_single_machine_manual (line 268) | def minimize_loss_single_machine_manual(loss,
function _is_gradient_task (line 337) | def _is_gradient_task(task_id, num_tasks):
function _is_cov_update_task (line 344) | def _is_cov_update_task(task_id, num_tasks):
function _is_inv_update_task (line 351) | def _is_inv_update_task(task_id, num_tasks):
function _num_gradient_tasks (line 358) | def _num_gradient_tasks(num_tasks):
function _make_distributed_train_op (line 365) | def _make_distributed_train_op(
function distributed_grads_only_and_ops_chief_worker (line 406) | def distributed_grads_only_and_ops_chief_worker(
function distributed_grads_and_ops_dedicated_workers (line 485) | def distributed_grads_and_ops_dedicated_workers(
function train_mnist_single_machine (line 557) | def train_mnist_single_machine(num_epochs,
function train_mnist_multitower (line 607) | def train_mnist_multitower(num_epochs, num_towers,
function train_mnist_distributed_sync_replicas (line 703) | def train_mnist_distributed_sync_replicas(task_id,
function train_mnist_estimator (line 764) | def train_mnist_estimator(num_epochs, use_fake_data=False):
FILE: kfac/examples/mnist.py
function load_mnist_as_tensors (line 32) | def load_mnist_as_tensors(flatten_images=True, dtype=tf.float32):
function load_mnist_as_dataset (line 74) | def load_mnist_as_dataset(flatten_images=True):
function load_mnist_as_iterator (line 92) | def load_mnist_as_iterator(num_epochs, batch_size,
FILE: kfac/examples/rnn_mnist.py
function make_train_op (line 96) | def make_train_op(batch_size,
function eval_model (line 166) | def eval_model(x, num_classes, layer_collection=None):
function compute_loss (line 252) | def compute_loss(inputs, labels, num_classes, layer_collection=None):
function load_mnist (line 274) | def load_mnist():
function main (line 312) | def main(_):
FILE: kfac/python/keras/callbacks.py
class HyperparameterDecay (line 27) | class HyperparameterDecay(tf.keras.callbacks.Callback):
method __init__ (line 30) | def __init__(self, hyperparameter, num_delay_steps=0, verbose=0):
method on_train_begin (line 45) | def on_train_begin(self, logs=None):
method on_epoch_begin (line 53) | def on_epoch_begin(self, epoch, logs=None):
method on_epoch_end (line 60) | def on_epoch_end(self, epoch, logs=None):
method _get_global_step (line 65) | def _get_global_step(self):
class PolynomialDecay (line 70) | class PolynomialDecay(HyperparameterDecay):
method __init__ (line 83) | def __init__(self,
method on_batch_begin (line 107) | def on_batch_begin(self, batch, logs=None):
class ExponentialDecay (line 116) | class ExponentialDecay(HyperparameterDecay):
method __init__ (line 137) | def __init__(self,
method on_batch_begin (line 181) | def on_batch_begin(self, batch, logs=None):
FILE: kfac/python/keras/optimizers.py
function _configure_kfac_kwargs_for_adaptive (line 53) | def _configure_kfac_kwargs_for_adaptive(kfac_kwargs, adaptive):
class Kfac (line 105) | class Kfac(tf.keras.optimizers.Optimizer):
method __init__ (line 108) | def __init__(self, # pylint: disable=invalid-name
method name (line 261) | def name(self):
method name (line 266) | def name(self, value):
method optimizer (line 276) | def optimizer(self):
method layers (line 291) | def layers(self):
method mutable_hyperparameters (line 295) | def mutable_hyperparameters(self):
method register_layers (line 298) | def register_layers(self, model=None, loss=None, layer_collection=None):
method register_train_batch (line 310) | def register_train_batch(self, train_batch, batch_size=None):
method minimize (line 330) | def minimize(self, loss, var_list, grad_loss=None, name=None):
method apply_gradients (line 337) | def apply_gradients(self, grads_and_vars, name=None):
method get_updates (line 341) | def get_updates(self, loss, params):
method get_config (line 344) | def get_config(self):
method _create_optimizer (line 350) | def _create_optimizer(self):
method _call_and_track_vars (line 402) | def _call_and_track_vars(self, method_name, *args, **kwargs):
method _set_hyper (line 422) | def _set_hyper(self, name, value):
FILE: kfac/python/keras/saving_utils.py
function _compile_args_from_training_config (line 38) | def _compile_args_from_training_config(training_config, custom_objects=N...
function load_model (line 69) | def load_model(filepath, custom_objects=None, optimizer_name=None):
FILE: kfac/python/keras/utils.py
function get_parent (line 48) | def get_parent(node):
function serialize_loss (line 70) | def serialize_loss(loss):
function serialize_fisher_approx (line 85) | def serialize_fisher_approx(fisher_approx):
function _get_verified_dict (line 96) | def _get_verified_dict(container, container_name, layer_names):
function register_layer (line 119) | def register_layer(layer_collection, layer, fisher_approx=None, **kwargs):
function register_loss (line 227) | def register_loss(layer_collection, layer, loss, **kwargs):
function get_layer_collection (line 276) | def get_layer_collection(model,
function get_loss_fn (line 400) | def get_loss_fn(model,
FILE: kfac/python/kernel_tests/data_reader_test.py
class DataReaderTest (line 28) | class DataReaderTest(tf.test.TestCase):
method test_read_batch (line 30) | def test_read_batch(self):
method test_cached_batch (line 50) | def test_cached_batch(self):
FILE: kfac/python/kernel_tests/estimator_test.py
class EstimatorTest (line 39) | class EstimatorTest(tf.test.TestCase):
method setUp (line 41) | def setUp(self):
method testEstimatorInitManualRegistration (line 61) | def testEstimatorInitManualRegistration(self):
method testVariableWrongNumberOfUses (line 93) | def testVariableWrongNumberOfUses(self, mock_uses):
method testInvalidEstimationMode (line 102) | def testInvalidEstimationMode(self):
method testGradientsModeBuild (line 112) | def testGradientsModeBuild(self):
method testEmpiricalModeBuild (line 122) | def testEmpiricalModeBuild(self):
method testCurvaturePropModeBuild (line 132) | def testCurvaturePropModeBuild(self):
method testExactModeBuild (line 142) | def testExactModeBuild(self):
method test_cov_update_thunks (line 152) | def test_cov_update_thunks(self):
method test_round_robin_placement (line 201) | def test_round_robin_placement(self):
method test_inv_update_thunks (line 237) | def test_inv_update_thunks(self):
FILE: kfac/python/kernel_tests/graph_search_test.py
function _build_model (line 32) | def _build_model():
function _build_mock_records (line 64) | def _build_mock_records():
function assert_fisher_blocks_match (line 101) | def assert_fisher_blocks_match(test_case, layer_collection_a,
function sparse_softmax_cross_entropy (line 121) | def sparse_softmax_cross_entropy(labels,
class GraphSearchTestCase (line 142) | class GraphSearchTestCase(tf.test.TestCase):
method testRegisterLayers (line 144) | def testRegisterLayers(self):
method test_register_records_order (line 175) | def test_register_records_order(self):
method test_multitower_examples_model (line 215) | def test_multitower_examples_model(self):
method test_multitower_multi_loss_function (line 297) | def test_multitower_multi_loss_function(self):
method test_filter_user_registered_records (line 335) | def test_filter_user_registered_records(self):
method test_filter_grouped_variable_records (line 356) | def test_filter_grouped_variable_records(self):
method test_filter_subgraph_records (line 383) | def test_filter_subgraph_records(self):
method test_rnn_multi (line 394) | def test_rnn_multi(self):
method test_graph_search_match_fail (line 484) | def test_graph_search_match_fail(self):
method test_specify_approximation (line 515) | def test_specify_approximation(self):
method test_specify_approximation_shared_parameters (line 580) | def test_specify_approximation_shared_parameters(self):
method test_tied_weights_untied_bias_registered_weights (line 613) | def test_tied_weights_untied_bias_registered_weights(self):
method test_tied_weights_untied_bias_registered_affine (line 641) | def test_tied_weights_untied_bias_registered_affine(self):
method test_tied_weights_untied_bias (line 675) | def test_tied_weights_untied_bias(self):
method test_tied_weights_untied_bias_registered_bias (line 694) | def test_tied_weights_untied_bias_registered_bias(self):
method test_multi_time_batch_fold (line 714) | def test_multi_time_batch_fold(self):
method test_multiple_weights (line 752) | def test_multiple_weights(self):
method test_subset_weights_manual_registration (line 783) | def test_subset_weights_manual_registration(self):
method mixed_usage_test (line 818) | def mixed_usage_test(self):
method test_resource_variable (line 842) | def test_resource_variable(self):
FILE: kfac/python/kernel_tests/keras_callbacks_test.py
class HyperParamTracker (line 32) | class HyperParamTracker(tf.keras.callbacks.Callback):
method __init__ (line 35) | def __init__(self, hyper, record_list, frequency):
method on_batch_end (line 40) | def on_batch_end(self, batch, logs=None):
method on_epoch_end (line 46) | def on_epoch_end(self, epoch, logs=None):
class CallbacksTest (line 53) | class CallbacksTest(parameterized.TestCase, tf.test.TestCase):
method __init__ (line 55) | def __init__(self, *args, **kwargs):
method setUp (line 62) | def setUp(self):
method testPolynomialDecayValues (line 67) | def testPolynomialDecayValues(self):
method testExponentialDampingValuesWithDecayRate (line 95) | def testExponentialDampingValuesWithDecayRate(self):
method testExponentialDampingValuesWithFinalValue (line 122) | def testExponentialDampingValuesWithFinalValue(self):
method testExponentialDampingValuesWithFinalValueAndRate (line 151) | def testExponentialDampingValuesWithFinalValueAndRate(self):
method testTrainHistory (line 191) | def testTrainHistory(self, hyper, callback):
method testDampingDecayFailsWithNoDamping (line 201) | def testDampingDecayFailsWithNoDamping(self):
method testExponentialDampingFailsNoRateOrFinalValue (line 210) | def testExponentialDampingFailsNoRateOrFinalValue(self):
method testExponentialDampingFailsWithAllOptionals (line 215) | def testExponentialDampingFailsWithAllOptionals(self):
FILE: kfac/python/kernel_tests/keras_optimizers_test.py
function _get_synthetic_mnist_dataset (line 37) | def _get_synthetic_mnist_dataset(train_size=64, test_size=16):
function _get_synthetic_mnist_train_tensors (line 52) | def _get_synthetic_mnist_train_tensors(
function _generate_target_fn (line 60) | def _generate_target_fn(num_examples):
function _generate_regression_data (line 83) | def _generate_regression_data(num_eg, num_train_eg):
function _simple_mlp (line 99) | def _simple_mlp():
function _mnist_model (line 107) | def _mnist_model(use_bias=True, use_separate_activation=True):
function _train_model (line 151) | def _train_model(data,
class KfacOptimizerTest (line 185) | class KfacOptimizerTest(parameterized.TestCase, tf.test.TestCase):
method __init__ (line 187) | def __init__(self, *args, **kwargs):
method setUp (line 191) | def setUp(self):
method testFunctionalInstantiation (line 196) | def testFunctionalInstantiation(self):
method testSequentialInstantiation (line 204) | def testSequentialInstantiation(self):
method testInstantiationWithLayerCollection (line 215) | def testInstantiationWithLayerCollection(self):
method testRNNFails (line 223) | def testRNNFails(self):
method testBiasAndActivations (line 238) | def testBiasAndActivations(self, use_bias, use_separate_activation):
method testRegression (line 243) | def testRegression(self):
method testClipNormFails (line 249) | def testClipNormFails(self):
method testClipValueFails (line 254) | def testClipValueFails(self):
method testLossTensor (line 259) | def testLossTensor(self):
method testArgsKwargs (line 266) | def testArgsKwargs(self):
method testConfig (line 305) | def testConfig(self):
method testFromConfig (line 333) | def testFromConfig(self, kwargs_updates):
method testGettingHyper (line 362) | def testGettingHyper(self, hyper_ctor):
method testGettingVariableHyperFails (line 376) | def testGettingVariableHyperFails(self):
method testSetTFVariableHyper (line 389) | def testSetTFVariableHyper(self, name, val):
method testSetFloatHyper (line 409) | def testSetFloatHyper(self, name, val):
method testModifyingTensorHypersFails (line 429) | def testModifyingTensorHypersFails(self, name, val):
method testLRBackwardsCompatibility (line 440) | def testLRBackwardsCompatibility(self):
method testMultipleLossTraining (line 456) | def testMultipleLossTraining(self):
method testRegisterLayersWithModel (line 481) | def testRegisterLayersWithModel(self, loss):
method testRegisterLayersWithLayerCollection (line 488) | def testRegisterLayersWithLayerCollection(self):
method testRegisterLayersCompiledModel (line 498) | def testRegisterLayersCompiledModel(self, loss):
method testTrainWithoutCreatingOptimizerFails (line 506) | def testTrainWithoutCreatingOptimizerFails(self):
method testEmptyCreateKfacOptimizerFails (line 514) | def testEmptyCreateKfacOptimizerFails(self):
method testSeed (line 519) | def testSeed(self):
method testNewOptSameVarScope (line 525) | def testNewOptSameVarScope(self):
method testGetSetWeights (line 534) | def testGetSetWeights(self):
method testTrainModelWithNormalization (line 577) | def testTrainModelWithNormalization(self, has_shift):
method testTrainModelWithFusedBN (line 595) | def testTrainModelWithFusedBN(self, has_shift):
method testTrainModelWithFusedBNAndLearningPhase (line 610) | def testTrainModelWithFusedBNAndLearningPhase(self, has_shift):
method testCustomTrainingLoopSequential (line 627) | def testCustomTrainingLoopSequential(self, input_conv_kwargs):
method testCustomTrainingLoopFunctionalInpTensor (line 650) | def testCustomTrainingLoopFunctionalInpTensor(self):
method testCustomTrainingLoopFunctionalInpShape (line 673) | def testCustomTrainingLoopFunctionalInpShape(self):
method testCustomTrainingLoopMakeOptimizerBeforeModelCall (line 699) | def testCustomTrainingLoopMakeOptimizerBeforeModelCall(self):
method testCustomTrainingUnwrappedTensorFails (line 723) | def testCustomTrainingUnwrappedTensorFails(self):
method testTrainingNestedModel (line 741) | def testTrainingNestedModel(self):
method testCustomTrainLoopNestedModel (line 759) | def testCustomTrainLoopNestedModel(self):
method testMutableHypers (line 795) | def testMutableHypers(self, not_mutable, kwargs_update):
method testPositionalArgsFail (line 802) | def testPositionalArgsFail(self):
method testSettingName (line 807) | def testSettingName(self):
method testAdaptiveModelFit (line 824) | def testAdaptiveModelFit(self, adaptive_kwargs):
method testAdaptiveModelFitBatchnorm (line 842) | def testAdaptiveModelFitBatchnorm(self, is_fused):
method testInferredBatchSize (line 861) | def testInferredBatchSize(self):
method testInferredBatchSizeFail (line 882) | def testInferredBatchSizeFail(self, kfac_kwargs):
method testOverrideAdaptiveDefaults (line 891) | def testOverrideAdaptiveDefaults(self):
method testAdaptiveWithLR (line 915) | def testAdaptiveWithLR(self, kfac_kwargs):
method testCustomLossFn (line 925) | def testCustomLossFn(self):
method testRegisterTrainBatch (line 949) | def testRegisterTrainBatch(self):
FILE: kfac/python/kernel_tests/keras_saving_utils_test.py
class SavingUtilsTest (line 55) | class SavingUtilsTest(tf.test.TestCase):
method test_sequential_model_saving (line 58) | def test_sequential_model_saving(self):
method test_functional_model_saving (line 115) | def test_functional_model_saving(self):
method test_saving_model_with_long_layer_names (line 151) | def test_saving_model_with_long_layer_names(self):
method test_saving_model_with_long_weights_names (line 194) | def test_saving_model_with_long_weights_names(self):
method test_model_saving_to_pre_created_h5py_file (line 243) | def test_model_saving_to_pre_created_h5py_file(self):
method test_saving_constant_initializer_with_numpy (line 284) | def test_saving_constant_initializer_with_numpy(self):
FILE: kfac/python/kernel_tests/keras_utils_test.py
function _mlp (line 34) | def _mlp():
function _cnn (line 44) | def _cnn():
function _two_loss_model (line 54) | def _two_loss_model(num_branch1_outputs=1, num_branch2_outputs=9):
class GetLayerCollectionTest (line 70) | class GetLayerCollectionTest(parameterized.TestCase, tf.test.TestCase):
method setUp (line 72) | def setUp(self):
method testValidLogitLossFunctionsCNN (line 84) | def testValidLogitLossFunctionsCNN(self, loss, kfac_loss):
method testValidLogitLossFunctionsMLP (line 105) | def testValidLogitLossFunctionsMLP(self, loss, kfac_loss):
method testValidMSE (line 123) | def testValidMSE(self, loss, model_builder):
method testInvalidLossFunctions (line 139) | def testInvalidLossFunctions(self, loss):
method testLayerRegistration (line 145) | def testLayerRegistration(self, model_builder):
method testMultipleLoss (line 166) | def testMultipleLoss(self, loss, loss_weights):
method testMultipleLossWeights (line 196) | def testMultipleLossWeights(self, loss_weights):
method testLossErrors (line 215) | def testLossErrors(self, loss):
method testLossWeightErrors (line 227) | def testLossWeightErrors(self, loss_weights):
method testInvalidCNNLayers (line 237) | def testInvalidCNNLayers(self, layer):
method testFisherApproxLayerNames (line 248) | def testFisherApproxLayerNames(self, fisher_approx):
method testFisherApproxLayerClass (line 275) | def testFisherApproxLayerClass(self, fisher_approx, block_types):
method testFisherApproxErrors (line 289) | def testFisherApproxErrors(self, fisher_approx):
method testSerializeFisherApprox (line 306) | def testSerializeFisherApprox(self, approx, correctly_serialized_approx):
method testSeed (line 310) | def testSeed(self):
method testNormalizationLayers (line 315) | def testNormalizationLayers(self, has_shift):
method testErrorWithBatchNormNoScale (line 336) | def testErrorWithBatchNormNoScale(self):
method testErrorWithLayerNormNoScale (line 345) | def testErrorWithLayerNormNoScale(self):
method testNumBatchNormUsesWithPhase (line 354) | def testNumBatchNormUsesWithPhase(self):
method testNumBatchNormUsesNoPhase (line 365) | def testNumBatchNormUsesNoPhase(self):
method testModelAsCallable (line 375) | def testModelAsCallable(self):
method testNestedModels (line 405) | def testNestedModels(self, fisher_approx):
method testMultiOutputNestedModelFails (line 447) | def testMultiOutputNestedModelFails(self):
class SerializeLossTest (line 462) | class SerializeLossTest(tf.test.TestCase, parameterized.TestCase):
method testSerializeLoss (line 473) | def testSerializeLoss(self, loss, correctly_serialized_loss):
class GetLossFnTest (line 478) | class GetLossFnTest(tf.test.TestCase, parameterized.TestCase):
method setUp (line 480) | def setUp(self):
method testCrossEntropy (line 492) | def testCrossEntropy(self, loss, label_shape, is_logits, use_regulariz...
method testCrossEntropyCustomLoop (line 531) | def testCrossEntropyCustomLoop(self, loss):
method testMSE (line 558) | def testMSE(self, loss):
method testMultiLoss (line 576) | def testMultiLoss(self, multi_loss, loss_weights):
FILE: kfac/python/kernel_tests/layer_collection_test.py
class MockFisherBlock (line 29) | class MockFisherBlock(object):
method __init__ (line 34) | def __init__(self, name='MockFisherBlock'):
method __eq__ (line 37) | def __eq__(self, other):
method __hash__ (line 40) | def __hash__(self):
class LayerParametersDictTest (line 44) | class LayerParametersDictTest(tf.test.TestCase):
method testSetItem (line 46) | def testSetItem(self):
method testSetItemOverlap (line 64) | def testSetItemOverlap(self):
class LayerCollectionTest (line 81) | class LayerCollectionTest(tf.test.TestCase):
method testLayerCollectionInit (line 83) | def testLayerCollectionInit(self):
method testRegisterBlocks (line 89) | def testRegisterBlocks(self):
method testRegisterBlocksMultipleRegistrations (line 151) | def testRegisterBlocksMultipleRegistrations(self):
method testRegisterSingleParamNotRegistered (line 161) | def testRegisterSingleParamNotRegistered(self):
method testShouldRegisterSingleParamRegistered (line 167) | def testShouldRegisterSingleParamRegistered(self):
method testRegisterSingleParamRegisteredInTuple (line 175) | def testRegisterSingleParamRegisteredInTuple(self):
method testRegisterTupleParamNotRegistered (line 184) | def testRegisterTupleParamNotRegistered(self):
method testRegisterTupleParamRegistered (line 193) | def testRegisterTupleParamRegistered(self):
method testRegisterTupleParamRegisteredInSuperset (line 203) | def testRegisterTupleParamRegisteredInSuperset(self):
method testRegisterTupleParamSomeRegistered (line 214) | def testRegisterTupleParamSomeRegistered(self):
method testRegisterTupleVarSomeRegisteredInOtherTuples (line 225) | def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
method testRegisterCategoricalPredictiveDistribution (line 237) | def testRegisterCategoricalPredictiveDistribution(self):
method testLossFunctionByName (line 252) | def testLossFunctionByName(self):
method testLossFunctionWithoutName (line 271) | def testLossFunctionWithoutName(self):
method testCategoricalPredictiveDistributionMultipleMinibatches (line 282) | def testCategoricalPredictiveDistributionMultipleMinibatches(self):
method testRegisterCategoricalPredictiveDistributionBatchSize1 (line 324) | def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
method testRegisterCategoricalPredictiveDistributionSpecifiedTargets (line 332) | def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self):
method testRegisterNormalPredictiveDistribution (line 343) | def testRegisterNormalPredictiveDistribution(self):
method testRegisterNormalPredictiveDistributionSpecifiedTargets (line 359) | def testRegisterNormalPredictiveDistributionSpecifiedTargets(self):
method ensureLayerReuseWorks (line 371) | def ensureLayerReuseWorks(self, register_fn):
method testRegisterFullyConnectedReuse (line 412) | def testRegisterFullyConnectedReuse(self):
method testRegisterConv2dReuse (line 427) | def testRegisterConv2dReuse(self):
method testReuseWithInvalidRegistration (line 447) | def testReuseWithInvalidRegistration(self):
method testMakeOrGetFactor (line 464) | def testMakeOrGetFactor(self):
method testMakeOrGetFactorCustomScope (line 478) | def testMakeOrGetFactorCustomScope(self):
method testIdentifyLinkedParametersSomeRegisteredInOtherTuples (line 492) | def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
method testIdentifySubsetPreviouslyRegisteredTensor (line 502) | def testIdentifySubsetPreviouslyRegisteredTensor(self):
method testSpecifyApproximation (line 511) | def testSpecifyApproximation(self):
method testDefaultLayerCollection (line 552) | def testDefaultLayerCollection(self):
FILE: kfac/python/kernel_tests/loss_functions_test.py
class InsertSliceInZerosTest (line 28) | class InsertSliceInZerosTest(tf.test.TestCase):
method testBadShape (line 30) | def testBadShape(self):
method test3d (line 35) | def test3d(self):
class CategoricalLogitsNegativeLogProbLossTest (line 44) | class CategoricalLogitsNegativeLogProbLossTest(tf.test.TestCase):
method testSample (line 46) | def testSample(self):
method testEvaluateOnTargets (line 59) | def testEvaluateOnTargets(self):
method testEvaluateOnSample (line 82) | def testEvaluateOnSample(self):
method testMultiplyFisherSingleVector (line 97) | def testMultiplyFisherSingleVector(self):
method testMultiplyFisherBatch (line 116) | def testMultiplyFisherBatch(self):
class OnehotCategoricalLogitsNegativeLogProbLossTest (line 134) | class OnehotCategoricalLogitsNegativeLogProbLossTest(tf.test.TestCase):
method testSample (line 136) | def testSample(self):
method testEvaluateOnTargets (line 149) | def testEvaluateOnTargets(self):
method testEvaluateOnSample (line 172) | def testEvaluateOnSample(self):
FILE: kfac/python/kernel_tests/op_queue_test.py
class OpQueueTest (line 27) | class OpQueueTest(tf.test.TestCase):
method testNextOp (line 29) | def testNextOp(self):
FILE: kfac/python/kernel_tests/optimizer_test.py
function dummy_layer_collection (line 30) | def dummy_layer_collection():
class OptimizerTest (line 37) | class OptimizerTest(tf.test.TestCase):
method testOptimizerInitInvalidMomentumRegistration (line 39) | def testOptimizerInitInvalidMomentumRegistration(self):
method testOptimizerInit (line 44) | def testOptimizerInit(self):
method testSquaredFisherNorm (line 73) | def testSquaredFisherNorm(self):
method testUpdateClipCoeff (line 83) | def testUpdateClipCoeff(self):
method testUpdateVelocities (line 112) | def testUpdateVelocities(self):
method testApplyGradients (line 148) | def testApplyGradients(self):
FILE: kfac/python/kernel_tests/periodic_inv_cov_update_kfac_opt_test.py
function _construct_layer_collection (line 32) | def _construct_layer_collection(layers, all_logits, var_list):
class PeriodicInvCovUpdateKfacOptTest (line 43) | class PeriodicInvCovUpdateKfacOptTest(tf.test.TestCase):
method test_train (line 45) | def test_train(self):
FILE: kfac/python/kernel_tests/utils_test.py
class SequenceDictTest (line 27) | class SequenceDictTest(tf.test.TestCase):
method testSequenceDictInit (line 29) | def testSequenceDictInit(self):
method testSequenceDictInitWithIterable (line 33) | def testSequenceDictInitWithIterable(self):
method testGetItemSingleKey (line 39) | def testGetItemSingleKey(self):
method testGetItemMultipleKeys (line 43) | def testGetItemMultipleKeys(self):
method testSetItemSingleKey (line 47) | def testSetItemSingleKey(self):
method testSetItemMultipleKeys (line 52) | def testSetItemMultipleKeys(self):
class SubGraphTest (line 60) | class SubGraphTest(tf.test.TestCase):
method testBasicGraph (line 62) | def testBasicGraph(self):
method testRepeatedAdds (line 73) | def testRepeatedAdds(self):
method testFilterList (line 82) | def testFilterList(self):
method testVariableUses (line 92) | def testVariableUses(self):
method testVariableUsesRelayOps (line 104) | def testVariableUsesRelayOps(self):
class UtilsTest (line 117) | class UtilsTest(tf.test.TestCase):
method _fully_connected_layer_params (line 119) | def _fully_connected_layer_params(self):
method _conv_layer_params (line 124) | def _conv_layer_params(self):
method testFullyConnectedLayerParamsTupleToMat2d (line 131) | def testFullyConnectedLayerParamsTupleToMat2d(self):
method testFullyConnectedLayerParamsTensorToMat2d (line 140) | def testFullyConnectedLayerParamsTensorToMat2d(self):
method testConvLayerParamsTupleToMat2d (line 148) | def testConvLayerParamsTupleToMat2d(self):
method testKron (line 155) | def testKron(self):
method testMat2dToFullyConnectedLayerParamsTuple (line 165) | def testMat2dToFullyConnectedLayerParamsTuple(self):
method testMat2dToFullyConnectedLayerParamsTensor (line 179) | def testMat2dToFullyConnectedLayerParamsTensor(self):
method testTensorsToColumn (line 189) | def testTensorsToColumn(self):
method testColumnToTensors (line 213) | def testColumnToTensors(self):
method testPosDefInvCholesky (line 243) | def testPosDefInvCholesky(self):
method testPosDefInvMatrixInverse (line 258) | def testPosDefInvMatrixInverse(self):
method testBatchExecute (line 273) | def testBatchExecute(self):
method testExtractConvolutionPatches (line 309) | def testExtractConvolutionPatches(self):
method testExtractPointwiseConv2dPatches (line 354) | def testExtractPointwiseConv2dPatches(self):
class AccumulatorVariableTest (line 388) | class AccumulatorVariableTest(tf.test.TestCase):
method test_assign_to_var (line 390) | def test_assign_to_var(self):
method test_accumulation (line 426) | def test_accumulation(self):
FILE: kfac/python/ops/curvature_matrix_vector_products.py
class CurvatureMatrixVectorProductComputer (line 28) | class CurvatureMatrixVectorProductComputer(object):
method __init__ (line 59) | def __init__(self, layer_collection, wrt_tensors,
method _loss_colocation_ops (line 77) | def _loss_colocation_ops(self):
method _losses (line 81) | def _losses(self):
method _inputs_to_losses (line 85) | def _inputs_to_losses(self):
method _inputs_to_losses_flat (line 89) | def _inputs_to_losses_flat(self):
method _total_loss (line 93) | def _total_loss(self):
method _get_loss_coeff (line 96) | def _get_loss_coeff(self, loss):
method _multiply_jacobian (line 100) | def _multiply_jacobian(self, vecs):
method _multiply_jacobian_transpose (line 110) | def _multiply_jacobian_transpose(self, loss_vecs):
method _multiply_across_losses (line 123) | def _multiply_across_losses(self, mult_func, vecs, coeff_mode="regular"):
method _multiply_loss_fisher (line 135) | def _multiply_loss_fisher(self, loss_vecs):
method _multiply_loss_fisher_factor (line 140) | def _multiply_loss_fisher_factor(self, loss_inner_vecs):
method _multiply_loss_fisher_factor_transpose (line 146) | def _multiply_loss_fisher_factor_transpose(self, loss_vecs):
method _multiply_loss_ggn (line 152) | def _multiply_loss_ggn(self, loss_vecs):
method _multiply_loss_ggn_factor (line 157) | def _multiply_loss_ggn_factor(self, loss_inner_vecs):
method _multiply_loss_ggn_factor_transpose (line 163) | def _multiply_loss_ggn_factor_transpose(self, loss_vecs):
method multiply_fisher (line 170) | def multiply_fisher(self, vecs):
method multiply_fisher_factor_transpose (line 176) | def multiply_fisher_factor_transpose(self, vecs):
method multiply_fisher_factor (line 181) | def multiply_fisher_factor(self, loss_inner_vecs):
method multiply_hessian (line 187) | def multiply_hessian(self, vecs):
method multiply_ggn (line 198) | def multiply_ggn(self, vecs):
method multiply_ggn_factor_transpose (line 204) | def multiply_ggn_factor_transpose(self, vecs):
method multiply_ggn_factor (line 209) | def multiply_ggn_factor(self, loss_inner_vecs):
method fisher_factor_inner_shapes (line 217) | def fisher_factor_inner_shapes(self):
method fisher_factor_inner_static_shapes (line 222) | def fisher_factor_inner_static_shapes(self):
method ggn_factor_inner_shapes (line 227) | def ggn_factor_inner_shapes(self):
method ggn_factor_inner_static_shapes (line 232) | def ggn_factor_inner_static_shapes(self):
FILE: kfac/python/ops/estimator.py
function make_fisher_estimator (line 35) | def make_fisher_estimator(placement_strategy=None, **kwargs):
class FisherEstimator (line 69) | class FisherEstimator(object):
method __init__ (line 79) | def __init__(self,
method variables (line 196) | def variables(self):
method damping (line 200) | def damping(self):
method blocks (line 204) | def blocks(self):
method factors (line 209) | def factors(self):
method name (line 214) | def name(self):
method layers (line 218) | def layers(self):
method mat_type (line 222) | def mat_type(self):
method params_stats (line 226) | def params_stats(self):
method _place_and_compute_transformation_thunks (line 230) | def _place_and_compute_transformation_thunks(self, thunks, params_list):
method _compute_transformation (line 248) | def _compute_transformation(self, vecs_and_vars, transform):
method multiply_inverse (line 281) | def multiply_inverse(self, vecs_and_vars):
method multiply (line 293) | def multiply(self, vecs_and_vars):
method multiply_matpower (line 305) | def multiply_matpower(self, exp, vecs_and_vars):
method multiply_cholesky (line 320) | def multiply_cholesky(self, vecs_and_vars, transpose=False):
method multiply_cholesky_inverse (line 336) | def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
method _instantiate_factors (line 361) | def _instantiate_factors(self):
method _register_matrix_functions (line 415) | def _register_matrix_functions(self):
method _finalize (line 424) | def _finalize(self):
method _check_batch_sizes (line 433) | def _check_batch_sizes(self, factor):
method _create_ops_and_vars_thunks (line 463) | def _create_ops_and_vars_thunks(self, scope=None):
method create_ops_and_vars_thunks (line 515) | def create_ops_and_vars_thunks(self, scope=None):
method make_vars_and_create_op_thunks (line 548) | def make_vars_and_create_op_thunks(self, scope=None):
method get_cov_vars (line 580) | def get_cov_vars(self):
method get_inv_vars (line 592) | def get_inv_vars(self):
method _create_cov_variable_thunk (line 604) | def _create_cov_variable_thunk(self, factor, scope):
method _create_cov_update_thunk (line 613) | def _create_cov_update_thunk(self, factor, scope):
method _create_inv_variable_thunk (line 631) | def _create_inv_variable_thunk(self, factor, scope):
method _create_inv_update_thunk (line 640) | def _create_inv_update_thunk(self, factor, scope):
method _get_grads_lists_gradients (line 649) | def _get_grads_lists_gradients(self, tensors):
method _get_grads_lists_empirical (line 659) | def _get_grads_lists_empirical(self, tensors):
method _get_transformed_random_signs (line 669) | def _get_transformed_random_signs(self):
method _get_grads_lists_curvature_prop (line 687) | def _get_grads_lists_curvature_prop(self, tensors):
method _get_grads_lists_exact (line 698) | def _get_grads_lists_exact(self, tensors):
class FisherEstimatorRoundRobin (line 727) | class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin,
class FisherEstimatorReplicaRoundRobin (line 733) | class FisherEstimatorReplicaRoundRobin(
FILE: kfac/python/ops/fisher_blocks.py
function set_global_constants (line 63) | def set_global_constants(normalize_damping_power=None, pi_type=None):
function normalize_damping (line 75) | def normalize_damping(damping, num_replications):
function compute_pi_tracenorm (line 82) | def compute_pi_tracenorm(left_cov, right_cov):
function compute_pi_adjusted_damping (line 115) | def compute_pi_adjusted_damping(left_cov, right_cov, damping):
class PackagedFunc (line 126) | class PackagedFunc(object):
method __init__ (line 132) | def __init__(self, func, func_id):
method __call__ (line 144) | def __call__(self):
method func_id (line 148) | def func_id(self):
function _package_func (line 153) | def _package_func(func, func_id):
class FisherBlock (line 158) | class FisherBlock(object):
method __init__ (line 166) | def __init__(self, layer_collection):
method instantiate_factors (line 170) | def instantiate_factors(self, grads_list, damping):
method register_matpower (line 182) | def register_matpower(self, exp):
method register_cholesky (line 191) | def register_cholesky(self):
method register_cholesky_inverse (line 196) | def register_cholesky_inverse(self):
method register_inverse (line 200) | def register_inverse(self):
method multiply_matpower (line 205) | def multiply_matpower(self, vector, exp):
method multiply_inverse (line 218) | def multiply_inverse(self, vector):
method multiply (line 229) | def multiply(self, vector):
method multiply_cholesky (line 241) | def multiply_cholesky(self, vector, transpose=False):
method multiply_cholesky_inverse (line 255) | def multiply_cholesky_inverse(self, vector, transpose=False):
method tensors_to_compute_grads (line 268) | def tensors_to_compute_grads(self):
method num_registered_towers (line 274) | def num_registered_towers(self):
class FullFB (line 283) | class FullFB(FisherBlock):
method register_matpower (line 286) | def register_matpower(self, exp):
method register_cholesky (line 289) | def register_cholesky(self):
method register_cholesky_inverse (line 292) | def register_cholesky_inverse(self):
method _multiply_matrix (line 295) | def _multiply_matrix(self, matrix, vector, transpose=False):
method multiply_matpower (line 300) | def multiply_matpower(self, vector, exp):
method multiply_cholesky (line 304) | def multiply_cholesky(self, vector, transpose=False):
method multiply_cholesky_inverse (line 308) | def multiply_cholesky_inverse(self, vector, transpose=False):
method full_fisher_block (line 312) | def full_fisher_block(self):
class NaiveFullFB (line 317) | class NaiveFullFB(FullFB):
method __init__ (line 327) | def __init__(self, layer_collection, params):
method instantiate_factors (line 339) | def instantiate_factors(self, grads_list, damping):
method tensors_to_compute_grads (line 345) | def tensors_to_compute_grads(self):
method register_additional_tower (line 348) | def register_additional_tower(self, batch_size):
method num_registered_towers (line 357) | def num_registered_towers(self):
method _batch_size (line 361) | def _batch_size(self):
class DiagonalFB (line 366) | class DiagonalFB(FisherBlock):
method register_matpower (line 369) | def register_matpower(self, exp):
method register_cholesky (line 374) | def register_cholesky(self):
method register_cholesky_inverse (line 379) | def register_cholesky_inverse(self):
method _multiply_matrix (line 384) | def _multiply_matrix(self, matrix, vector):
method multiply_matpower (line 389) | def multiply_matpower(self, vector, exp):
method multiply_cholesky (line 393) | def multiply_cholesky(self, vector, transpose=False):
method multiply_cholesky_inverse (line 397) | def multiply_cholesky_inverse(self, vector, transpose=False):
method full_fisher_block (line 401) | def full_fisher_block(self):
class NaiveDiagonalFB (line 405) | class NaiveDiagonalFB(DiagonalFB):
method __init__ (line 414) | def __init__(self, layer_collection, params):
method instantiate_factors (line 426) | def instantiate_factors(self, grads_list, damping):
method tensors_to_compute_grads (line 432) | def tensors_to_compute_grads(self):
method register_additional_tower (line 435) | def register_additional_tower(self, batch_size):
method num_registered_towers (line 444) | def num_registered_towers(self):
method _batch_size (line 448) | def _batch_size(self):
class InputOutputMultiTower (line 452) | class InputOutputMultiTower(object):
method __init__ (line 455) | def __init__(self, *args, **kwargs):
method _process_data (line 460) | def _process_data(self, grads_list):
method tensors_to_compute_grads (line 514) | def tensors_to_compute_grads(self):
method register_additional_tower (line 518) | def register_additional_tower(self, inputs, outputs):
method num_registered_towers (line 523) | def num_registered_towers(self):
method _inputs (line 529) | def _inputs(self):
method _outputs (line 533) | def _outputs(self):
class FullyConnectedDiagonalFB (line 537) | class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
method __init__ (line 560) | def __init__(self, layer_collection, has_bias=False):
method instantiate_factors (line 573) | def instantiate_factors(self, grads_list, damping):
class ConvDiagonalFB (line 583) | class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
method __init__ (line 608) | def __init__(self,
method _factor_implementation (line 669) | def _factor_implementation(self):
method instantiate_factors (line 672) | def instantiate_factors(self, grads_list, damping):
class ScaleAndShiftFullFB (line 695) | class ScaleAndShiftFullFB(InputOutputMultiTower, FullFB):
method __init__ (line 704) | def __init__(self, layer_collection, broadcast_dims_scale,
method instantiate_factors (line 723) | def instantiate_factors(self, grads_list, damping):
class ScaleAndShiftDiagonalFB (line 735) | class ScaleAndShiftDiagonalFB(InputOutputMultiTower, DiagonalFB):
method __init__ (line 744) | def __init__(self, layer_collection, broadcast_dims_scale,
method instantiate_factors (line 763) | def instantiate_factors(self, grads_list, damping):
class KroneckerProductFB (line 775) | class KroneckerProductFB(FisherBlock):
method _setup_damping (line 782) | def _setup_damping(self, damping, normalization=None):
method register_matpower (line 818) | def register_matpower(self, exp):
method register_cholesky (line 822) | def register_cholesky(self):
method register_cholesky_inverse (line 826) | def register_cholesky_inverse(self):
method damping (line 831) | def damping(self):
method input_factor (line 843) | def input_factor(self):
method output_factor (line 847) | def output_factor(self):
method _renorm_coeff (line 851) | def _renorm_coeff(self):
method _multiply_factored_matrix (line 862) | def _multiply_factored_matrix(self, left_factor, right_factor, vector,
method multiply_matpower (line 875) | def multiply_matpower(self, vector, exp):
method multiply_cholesky (line 884) | def multiply_cholesky(self, vector, transpose=False):
method multiply_cholesky_inverse (line 893) | def multiply_cholesky_inverse(self, vector, transpose=False):
method full_fisher_block (line 904) | def full_fisher_block(self):
class FullyConnectedKFACBasicFB (line 918) | class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
method __init__ (line 925) | def __init__(self, layer_collection, has_bias=False,
method instantiate_factors (line 946) | def instantiate_factors(self, grads_list, damping):
class ConvKFCBasicFB (line 979) | class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
method __init__ (line 1003) | def __init__(self,
method instantiate_factors (line 1063) | def instantiate_factors(self, grads_list, damping):
method _renorm_coeff (line 1090) | def _renorm_coeff(self):
class DepthwiseConvDiagonalFB (line 1094) | class DepthwiseConvDiagonalFB(ConvDiagonalFB):
method __init__ (line 1100) | def __init__(self,
method _multiply_matrix (line 1156) | def _multiply_matrix(self, matrix, vector):
class DepthwiseConvKFCBasicFB (line 1163) | class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
method __init__ (line 1169) | def __init__(self,
method _multiply_factored_matrix (line 1226) | def _multiply_factored_matrix(self, left_factor, right_factor, vector,
function depthwise_conv2d_filter_to_conv2d_filter (line 1237) | def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pyli...
function conv2d_filter_to_depthwise_conv2d_filter (line 1280) | def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pyli...
function maybe_tuple (line 1324) | def maybe_tuple(obj):
class InputOutputMultiTowerMultiUse (line 1330) | class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
method __init__ (line 1333) | def __init__(self, num_uses=None, *args, **kwargs):
method _process_data (line 1337) | def _process_data(self, grads_list):
class FullyConnectedMultiIndepFB (line 1487) | class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse,
method __init__ (line 1496) | def __init__(self, layer_collection, has_bias=False, num_uses=None,
method instantiate_factors (line 1522) | def instantiate_factors(self, grads_list, damping):
method _renorm_coeff (line 1546) | def _renorm_coeff(self):
class ConvKFCBasicMultiIndepFB (line 1550) | class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,
method __init__ (line 1560) | def __init__(self,
method instantiate_factors (line 1607) | def instantiate_factors(self, grads_list, damping):
method _renorm_coeff (line 1629) | def _renorm_coeff(self):
class SeriesFBApproximation (line 1633) | class SeriesFBApproximation(object):
class FullyConnectedSeriesFB (line 1639) | class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
method __init__ (line 1653) | def __init__(self,
method _num_timesteps (line 1684) | def _num_timesteps(self):
method _renorm_coeff (line 1688) | def _renorm_coeff(self):
method instantiate_factors (line 1693) | def instantiate_factors(self, grads_list, damping):
method register_matpower (line 1707) | def register_matpower(self, exp):
method multiply_matpower (line 1723) | def multiply_matpower(self, vector, exp):
method multiply_cholesky (line 1836) | def multiply_cholesky(self, vector):
method multiply_cholesky_inverse (line 1840) | def multiply_cholesky_inverse(self, vector):
FILE: kfac/python/ops/fisher_factors.py
function set_global_constants (line 127) | def set_global_constants(init_covariances_at_zero=None,
function maybe_place_on_device (line 211) | def maybe_place_on_device(device):
function compute_cov (line 219) | def compute_cov(tensor, tensor_right=None, normalizer=None):
function append_homog (line 247) | def append_homog(tensor, homog_value=None):
function scope_string_from_params (line 272) | def scope_string_from_params(params):
function scope_string_from_name (line 316) | def scope_string_from_name(tensor):
function scalar_or_tensor_to_string (line 325) | def scalar_or_tensor_to_string(val):
function list_to_string (line 329) | def list_to_string(lst):
function graph_func_to_id (line 334) | def graph_func_to_id(func):
function graph_func_to_string (line 340) | def graph_func_to_string(func):
function _subsample_patches (line 345) | def _subsample_patches(patches, name=None):
function _random_tensor_gather (line 390) | def _random_tensor_gather(array, num_ind, name=None):
class FisherFactor (line 412) | class FisherFactor(object):
method __init__ (line 429) | def __init__(self):
method _var_scope (line 435) | def _var_scope(self):
method name (line 444) | def name(self):
method _cov_shape (line 448) | def _cov_shape(self):
method _num_sources (line 453) | def _num_sources(self):
method _num_towers (line 465) | def _num_towers(self):
method _dtype (line 469) | def _dtype(self):
method _partial_batch_size (line 474) | def _partial_batch_size(self, source=0, tower=0):
method batch_size (line 478) | def batch_size(self, source=0):
method check_partial_batch_sizes (line 483) | def check_partial_batch_sizes(self):
method _cov_initializer (line 529) | def _cov_initializer(self):
method instantiate_cov_variables (line 533) | def instantiate_cov_variables(self):
method _compute_new_cov (line 545) | def _compute_new_cov(self, source, tower):
method _compute_total_new_cov (line 559) | def _compute_total_new_cov(self):
method make_covariance_update_op (line 585) | def make_covariance_update_op(self, ema_decay, ema_weight):
method _get_data_device (line 604) | def _get_data_device(self, tower):
method instantiate_inv_variables (line 608) | def instantiate_inv_variables(self):
method make_inverse_update_ops (line 613) | def make_inverse_update_ops(self):
method cov (line 618) | def cov(self):
method get_cov_vars (line 621) | def get_cov_vars(self):
method get_inv_vars (line 624) | def get_inv_vars(self):
method get_cov_as_linear_operator (line 628) | def get_cov_as_linear_operator(self):
method register_matpower (line 633) | def register_matpower(self, exp, damping_func):
method register_cholesky (line 637) | def register_cholesky(self, damping_func):
method register_cholesky_inverse (line 641) | def register_cholesky_inverse(self, damping_func):
method get_matpower (line 645) | def get_matpower(self, exp, damping_func):
method get_cholesky (line 649) | def get_cholesky(self, damping_func):
method get_cholesky_inverse (line 653) | def get_cholesky_inverse(self, damping_func):
class DenseSquareMatrixFactor (line 657) | class DenseSquareMatrixFactor(FisherFactor):
method __init__ (line 674) | def __init__(self):
method get_cov_as_linear_operator (line 688) | def get_cov_as_linear_operator(self):
method _register_damping (line 695) | def _register_damping(self, damping_func):
method register_inverse (line 701) | def register_inverse(self, damping_func):
method register_matpower (line 705) | def register_matpower(self, exp, damping_func):
method register_cholesky (line 725) | def register_cholesky(self, damping_func):
method register_cholesky_inverse (line 741) | def register_cholesky_inverse(self, damping_func):
method get_inv_vars (line 757) | def get_inv_vars(self):
method instantiate_inv_variables (line 764) | def instantiate_inv_variables(self):
method make_inverse_update_ops (line 810) | def make_inverse_update_ops(self):
method get_inverse (line 882) | def get_inverse(self, damping_func):
method get_matpower (line 886) | def get_matpower(self, exp, damping_func):
method get_cholesky (line 905) | def get_cholesky(self, damping_func):
method get_cholesky_inverse (line 916) | def get_cholesky_inverse(self, damping_func):
method get_eigendecomp (line 927) | def get_eigendecomp(self):
class NaiveFullFactor (line 944) | class NaiveFullFactor(DenseSquareMatrixFactor):
method __init__ (line 951) | def __init__(self,
method _var_scope (line 960) | def _var_scope(self):
method _cov_shape (line 965) | def _cov_shape(self):
method _num_sources (line 971) | def _num_sources(self):
method _num_towers (line 975) | def _num_towers(self):
method _dtype (line 979) | def _dtype(self):
method _partial_batch_size (line 982) | def _partial_batch_size(self, source=0, tower=0):
method _compute_new_cov (line 986) | def _compute_new_cov(self, source, tower):
method _get_data_device (line 994) | def _get_data_device(self, tower):
class DiagonalFactor (line 999) | class DiagonalFactor(FisherFactor):
method get_cov_as_linear_operator (line 1006) | def get_cov_as_linear_operator(self):
method _cov_initializer (line 1013) | def _cov_initializer(self):
method _matrix_diagonal (line 1017) | def _matrix_diagonal(self):
method make_inverse_update_ops (line 1020) | def make_inverse_update_ops(self):
method instantiate_inv_variables (line 1023) | def instantiate_inv_variables(self):
method register_matpower (line 1026) | def register_matpower(self, exp, damping_func):
method register_cholesky (line 1029) | def register_cholesky(self, damping_func):
method register_cholesky_inverse (line 1032) | def register_cholesky_inverse(self, damping_func):
method get_matpower (line 1035) | def get_matpower(self, exp, damping_func):
method get_cholesky (line 1044) | def get_cholesky(self, damping_func):
method get_cholesky_inverse (line 1047) | def get_cholesky_inverse(self, damping_func):
class NaiveDiagonalFactor (line 1051) | class NaiveDiagonalFactor(DiagonalFactor):
method __init__ (line 1058) | def __init__(self,
method _var_scope (line 1074) | def _var_scope(self):
method _cov_shape (line 1079) | def _cov_shape(self):
method _num_sources (line 1083) | def _num_sources(self):
method _num_towers (line 1087) | def _num_towers(self):
method _dtype (line 1091) | def _dtype(self):
method _partial_batch_size (line 1094) | def _partial_batch_size(self, source=0, tower=0):
method _compute_new_cov (line 1098) | def _compute_new_cov(self, source, tower):
method _get_data_device (line 1103) | def _get_data_device(self, tower):
class DiagonalKroneckerFactor (line 1107) | class DiagonalKroneckerFactor(DiagonalFactor):
method __init__ (line 1128) | def __init__(self, tensors, has_bias=False, dtype=None):
method _var_scope (line 1157) | def _var_scope(self):
method _cov_shape (line 1162) | def _cov_shape(self):
method _num_sources (line 1170) | def _num_sources(self):
method _num_towers (line 1174) | def _num_towers(self):
method _dtype (line 1178) | def _dtype(self):
method _partial_batch_size (line 1181) | def _partial_batch_size(self, source=0, tower=0):
method _compute_new_cov (line 1184) | def _compute_new_cov(self, source, tower):
method _compute_new_cov_from_tensor (line 1187) | def _compute_new_cov_from_tensor(self, tensor):
method _get_data_device (line 1226) | def _get_data_device(self, tower):
class DiagonalMultiKF (line 1230) | class DiagonalMultiKF(DiagonalKroneckerFactor):
method __init__ (line 1232) | def __init__(self, tensors, num_uses, has_bias=False, dtype=None):
method _partial_batch_size (line 1237) | def _partial_batch_size(self, source=0, tower=0):
method _cov_shape (line 1257) | def _cov_shape(self):
method _compute_new_cov (line 1268) | def _compute_new_cov(self, source, tower):
class FullyConnectedDiagonalFactor (line 1280) | class FullyConnectedDiagonalFactor(DiagonalFactor):
method __init__ (line 1291) | def __init__(self,
method _var_scope (line 1314) | def _var_scope(self):
method _cov_shape (line 1319) | def _cov_shape(self):
method _num_sources (line 1325) | def _num_sources(self):
method _num_towers (line 1329) | def _num_towers(self):
method _dtype (line 1333) | def _dtype(self):
method _partial_batch_size (line 1336) | def _partial_batch_size(self, source=0, tower=0):
method make_covariance_update_op (line 1339) | def make_covariance_update_op(self, ema_decay, ema_weight):
method _compute_new_cov (line 1353) | def _compute_new_cov(self, source, tower):
method _get_data_device (line 1367) | def _get_data_device(self, tower):
class ScaleAndShiftFactor (line 1372) | class ScaleAndShiftFactor(FisherFactor):
method __init__ (line 1374) | def __init__(self,
method _var_scope (line 1396) | def _var_scope(self):
method _cov_shape (line 1402) | def _cov_shape(self):
method _num_sources (line 1423) | def _num_sources(self):
method _num_towers (line 1427) | def _num_towers(self):
method _dtype (line 1431) | def _dtype(self):
method _partial_batch_size (line 1434) | def _partial_batch_size(self, source=0, tower=0):
method _compute_new_cov (line 1437) | def _compute_new_cov(self, source, tower):
method _get_data_device (line 1478) | def _get_data_device(self, tower):
class ScaleAndShiftFullFactor (line 1482) | class ScaleAndShiftFullFactor(ScaleAndShiftFactor, DenseSquareMatrixFact...
method __init__ (line 1484) | def __init__(self,
class ScaleAndShiftDiagonalFactor (line 1500) | class ScaleAndShiftDiagonalFactor(ScaleAndShiftFactor, DiagonalFactor):
method __init__ (line 1502) | def __init__(self,
class ConvDiagonalFactor (line 1518) | class ConvDiagonalFactor(DiagonalFactor):
method __init__ (line 1521) | def __init__(self,
method _var_scope (line 1592) | def _var_scope(self):
method _cov_shape (line 1597) | def _cov_shape(self):
method _num_sources (line 1605) | def _num_sources(self):
method _num_towers (line 1609) | def _num_towers(self):
method _dtype (line 1613) | def _dtype(self):
method _partial_batch_size (line 1616) | def _partial_batch_size(self, source=0, tower=0):
method make_covariance_update_op (line 1619) | def make_covariance_update_op(self, ema_decay, ema_weight):
method _compute_new_cov (line 1652) | def _compute_new_cov(self, source, tower):
method _convdiag_sum_of_squares (line 1663) | def _convdiag_sum_of_squares(self, patches, outputs_grad):
method _get_data_device (line 1670) | def _get_data_device(self, tower):
class FullyConnectedKroneckerFactor (line 1674) | class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):
method __init__ (line 1678) | def __init__(self,
method _var_scope (line 1704) | def _var_scope(self):
method _cov_shape (line 1709) | def _cov_shape(self):
method _num_sources (line 1714) | def _num_sources(self):
method _num_towers (line 1718) | def _num_towers(self):
method _dtype (line 1722) | def _dtype(self):
method _partial_batch_size (line 1725) | def _partial_batch_size(self, source=0, tower=0):
method _compute_new_cov (line 1728) | def _compute_new_cov(self, source, tower):
method _get_data_device (line 1734) | def _get_data_device(self, tower):
class ConvInputKroneckerFactor (line 1738) | class ConvInputKroneckerFactor(DenseSquareMatrixFactor):
method __init__ (line 1750) | def __init__(self,
method _var_scope (line 1814) | def _var_scope(self):
method _cov_shape (line 1822) | def _cov_shape(self):
method _num_sources (line 1829) | def _num_sources(self):
method _num_towers (line 1833) | def _num_towers(self):
method _dtype (line 1837) | def _dtype(self):
method _partial_batch_size (line 1840) | def _partial_batch_size(self, source=0, tower=0):
method _compute_new_cov (line 1844) | def _compute_new_cov(self, source, tower):
method _get_data_device (line 1936) | def _get_data_device(self, tower):
class ConvInputMultiKF (line 1940) | class ConvInputMultiKF(ConvInputKroneckerFactor):
method __init__ (line 1942) | def __init__(self,
method _partial_batch_size (line 1969) | def _partial_batch_size(self, source=0, tower=0):
class ConvInputSUAKroneckerFactor (line 1979) | class ConvInputSUAKroneckerFactor(FisherFactor):
method __init__ (line 1988) | def __init__(self, inputs, filter_shape, has_bias=False):
method _var_scope (line 2021) | def _var_scope(self):
method _cov_shape (line 2026) | def _cov_shape(self):
method _num_sources (line 2039) | def _num_sources(self):
method _num_towers (line 2043) | def _num_towers(self):
method _dtype (line 2047) | def _dtype(self):
method mu (line 2051) | def mu(self):
method _partial_batch_size (line 2054) | def _partial_batch_size(self, source=0, tower=0):
method _register_damping (line 2058) | def _register_damping(self, damping_func):
method get_inv_vars (line 2064) | def get_inv_vars(self):
method instantiate_cov_variables (line 2069) | def instantiate_cov_variables(self):
method make_covariance_update_op (line 2087) | def make_covariance_update_op(self, ema_decay, ema_weight):
method _compute_new_cov (line 2124) | def _compute_new_cov(self, source, tower):
method register_matpower (line 2135) | def register_matpower(self, exp, damping_func):
method _compute_sm_rank_one_update_quants (line 2159) | def _compute_sm_rank_one_update_quants(self, exp, damping_id, damping_...
method get_matpower (line 2173) | def get_matpower(self, exp, damping_func):
method make_inverse_update_ops (line 2256) | def make_inverse_update_ops(self):
method get_inverse (line 2288) | def get_inverse(self, damping_func):
method instantiate_inv_variables (line 2292) | def instantiate_inv_variables(self):
method _make_cov_linear_operator (line 2340) | def _make_cov_linear_operator(self, damping=None):
method get_cov_as_linear_operator (line 2385) | def get_cov_as_linear_operator(self):
method get_cholesky (line 2388) | def get_cholesky(self, damping_func):
method get_cholesky_inverse (line 2392) | def get_cholesky_inverse(self, damping_func):
method register_cholesky (line 2396) | def register_cholesky(self):
method register_cholesky_inverse (line 2400) | def register_cholesky_inverse(self):
method _get_data_device (line 2404) | def _get_data_device(self, tower):
class ConvOutputKroneckerFactor (line 2408) | class ConvOutputKroneckerFactor(DenseSquareMatrixFactor):
method __init__ (line 2419) | def __init__(self, outputs_grads, data_format=None):
method _var_scope (line 2438) | def _var_scope(self):
method _cov_shape (line 2443) | def _cov_shape(self):
method _num_sources (line 2448) | def _num_sources(self):
method _num_towers (line 2452) | def _num_towers(self):
method _dtype (line 2456) | def _dtype(self):
method _partial_batch_size (line 2459) | def _partial_batch_size(self, source=0, tower=0):
method _compute_new_cov (line 2462) | def _compute_new_cov(self, source, tower):
method _get_data_device (line 2476) | def _get_data_device(self, tower):
class ConvOutputMultiKF (line 2480) | class ConvOutputMultiKF(ConvOutputKroneckerFactor):
method __init__ (line 2482) | def __init__(self, outputs_grads, num_uses, data_format=None):
method _partial_batch_size (line 2487) | def _partial_batch_size(self, source=0, tower=0):
class FullyConnectedMultiKF (line 2497) | class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
method __init__ (line 2500) | def __init__(self,
method _num_timesteps (line 2530) | def _num_timesteps(self):
method _partial_batch_size (line 2533) | def _partial_batch_size(self, source=0, tower=0):
method _var_scope (line 2542) | def _var_scope(self):
method get_inv_vars (line 2547) | def get_inv_vars(self):
method make_covariance_update_op (line 2553) | def make_covariance_update_op(self, ema_decay, ema_weight):
method _compute_new_cov (line 2582) | def _compute_new_cov(self, source, tower):
method _compute_new_cov_dt1 (line 2591) | def _compute_new_cov_dt1(self, source, tower): # pylint: disable=miss...
method _cov_shape (line 2614) | def _cov_shape(self):
method _get_data_device (line 2622) | def _get_data_device(self, tower):
method _vec_shape (line 2626) | def _vec_shape(self):
method get_option1quants (line 2630) | def get_option1quants(self, damping_func):
method get_option2quants (line 2634) | def get_option2quants(self, damping_func):
method cov_dt1 (line 2639) | def cov_dt1(self):
method get_cov_vars (line 2643) | def get_cov_vars(self):
method register_cov_dt1 (line 2649) | def register_cov_dt1(self):
method instantiate_cov_variables (line 2652) | def instantiate_cov_variables(self):
method register_option1quants (line 2664) | def register_option1quants(self, damping_func):
method register_option2quants (line 2669) | def register_option2quants(self, damping_func):
method instantiate_inv_variables (line 2674) | def instantiate_inv_variables(self):
method make_inverse_update_ops (line 2734) | def make_inverse_update_ops(self):
FILE: kfac/python/ops/kfac_utils/async_inv_cov_update_kfac_opt.py
class AsyncInvCovUpdateKfacOpt (line 30) | class AsyncInvCovUpdateKfacOpt(optimizer.KfacOptimizer):
method __init__ (line 49) | def __init__(self,
method _make_ops (line 80) | def _make_ops(self, update_thunks):
method apply_gradients (line 83) | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
method run_cov_inv_ops (line 94) | def run_cov_inv_ops(self, sess):
method _run_ops (line 105) | def _run_ops(self, sess):
method stop_cov_inv_ops (line 120) | def stop_cov_inv_ops(self, sess):
method _set_up_op_name_queue (line 125) | def _set_up_op_name_queue(self, ops_to_run):
FILE: kfac/python/ops/kfac_utils/data_reader.py
function _slice_data (line 30) | def _slice_data(stored_data, size):
class VariableBatchReader (line 34) | class VariableBatchReader(object):
method __init__ (line 37) | def __init__(self, dataset, max_batch_size):
method __call__ (line 52) | def __call__(self, batch_size):
class CachedDataReader (line 72) | class CachedDataReader(VariableBatchReader):
method __init__ (line 75) | def __init__(self, dataset, max_batch_size):
method __call__ (line 104) | def __call__(self, batch_size):
method cached_batch (line 128) | def cached_batch(self):
FILE: kfac/python/ops/kfac_utils/data_reader_alt.py
function _extract_data (line 34) | def _extract_data(tensor_list, indices):
class VariableBatchReader (line 38) | class VariableBatchReader(object):
method __init__ (line 41) | def __init__(self, dataset, num_examples):
method __call__ (line 54) | def __call__(self, batch_size):
class CachedDataReader (line 76) | class CachedDataReader(VariableBatchReader):
method __init__ (line 79) | def __init__(self, dataset, num_examples):
method __call__ (line 105) | def __call__(self, batch_size):
method cached_batch (line 126) | def cached_batch(self):
FILE: kfac/python/ops/kfac_utils/periodic_inv_cov_update_kfac_opt.py
class PeriodicInvCovUpdateKfacOpt (line 31) | class PeriodicInvCovUpdateKfacOpt(optimizer.KfacOptimizer):
method __init__ (line 52) | def __init__(self,
method minimize (line 123) | def minimize(self,
method apply_gradients (line 170) | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
method kfac_update_ops (line 177) | def kfac_update_ops(self):
FILE: kfac/python/ops/layer_collection.py
function get_default_layer_collection (line 107) | def get_default_layer_collection():
function set_default_layer_collection (line 117) | def set_default_layer_collection(layer_collection):
class LayerParametersDict (line 126) | class LayerParametersDict(OrderedDict):
method __init__ (line 132) | def __init__(self, *args, **kwargs):
method __setitem__ (line 136) | def __setitem__(self, key, value):
method __delitem__ (line 145) | def __delitem__(self, key):
method __getitem__ (line 150) | def __getitem__(self, key):
method __contains__ (line 154) | def __contains__(self, key):
method _canonicalize_key (line 158) | def _canonicalize_key(self, key):
class LayerCollection (line 168) | class LayerCollection(object):
method __init__ (line 194) | def __init__(self,
method losses (line 276) | def losses(self):
method towers_by_loss (line 281) | def towers_by_loss(self):
method registered_variables (line 286) | def registered_variables(self):
method linked_parameters (line 294) | def linked_parameters(self):
method default_generic_approximation (line 308) | def default_generic_approximation(self):
method set_default_generic_approximation (line 311) | def set_default_generic_approximation(self, value):
method default_fully_connected_approximation (line 319) | def default_fully_connected_approximation(self):
method set_default_fully_connected_approximation (line 322) | def set_default_fully_connected_approximation(self, value):
method default_conv2d_approximation (line 330) | def default_conv2d_approximation(self):
method set_default_conv2d_approximation (line 333) | def set_default_conv2d_approximation(self, value):
method default_fully_connected_multi_approximation (line 341) | def default_fully_connected_multi_approximation(self):
method set_default_fully_connected_multi_approximation (line 344) | def set_default_fully_connected_multi_approximation(self, value):
method default_conv2d_multi_approximation (line 351) | def default_conv2d_multi_approximation(self):
method set_default_conv2d_multi_approximation (line 354) | def set_default_conv2d_multi_approximation(self, value):
method default_scale_and_shift_approximation (line 361) | def default_scale_and_shift_approximation(self):
method set_default_scale_and_shift_approximation (line 364) | def set_default_scale_and_shift_approximation(self, value):
method auto_register_layers (line 370) | def auto_register_layers(self, var_list=None, batch_size=None):
method finalize (line 411) | def finalize(self):
method _register_block (line 421) | def _register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
method _register_loss_function (line 496) | def _register_loss_function(self,
method _get_use_count_map (line 567) | def _get_use_count_map(self):
method _add_uses (line 571) | def _add_uses(self, params, uses):
method check_registration (line 582) | def check_registration(self, variables):
method get_blocks (line 626) | def get_blocks(self):
method get_factors (line 629) | def get_factors(self):
method graph (line 633) | def graph(self):
method subgraph (line 637) | def subgraph(self):
method define_linked_parameters (line 640) | def define_linked_parameters(self, params, approximation=None):
method _create_subgraph (line 686) | def _create_subgraph(self):
method eval_losses (line 692) | def eval_losses(self, target_mode="data", coeff_mode="regular"):
method total_loss (line 716) | def total_loss(self, coeff_mode="regular"):
method total_sampled_loss (line 720) | def total_sampled_loss(self, coeff_mode="regular"):
method _get_linked_approx (line 724) | def _get_linked_approx(self, params):
method _get_block_type (line 732) | def _get_block_type(self, params, approx, default, approx_to_type):
method register_fully_connected (line 743) | def register_fully_connected(self,
method register_conv1d (line 796) | def register_conv1d(self,
method register_conv2d (line 857) | def register_conv2d(self,
method register_convolution (line 966) | def register_convolution(self,
method register_depthwise_conv2d (line 1027) | def register_depthwise_conv2d(self,
method register_separable_conv2d (line 1086) | def register_separable_conv2d(self,
method register_generic (line 1159) | def register_generic(self,
method register_fully_connected_multi (line 1200) | def register_fully_connected_multi(self, params, inputs, outputs,
method register_conv2d_multi (line 1291) | def register_conv2d_multi(self,
method register_scale_and_shift (line 1379) | def register_scale_and_shift(self,
method register_categorical_predictive_distribution (line 1489) | def register_categorical_predictive_distribution(self,
method register_softmax_cross_entropy_loss (line 1532) | def register_softmax_cross_entropy_loss(self,
method register_normal_predictive_distribution (line 1573) | def register_normal_predictive_distribution(self,
method register_squared_error_loss (line 1618) | def register_squared_error_loss(self,
method register_multi_bernoulli_predictive_distribution (line 1657) | def register_multi_bernoulli_predictive_distribution(self,
method register_sigmoid_cross_entropy_loss (line 1701) | def register_sigmoid_cross_entropy_loss(self,
method make_or_get_factor (line 1741) | def make_or_get_factor(self, cls, args):
method as_default (line 1773) | def as_default(self):
FILE: kfac/python/ops/linear_operator.py
class LinearOperatorExtras (line 29) | class LinearOperatorExtras(object): # pylint: disable=missing-docstring
method matmul (line 31) | def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): ...
method matmul_right (line 47) | def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matm...
class LinearOperatorFullMatrix (line 66) | class LinearOperatorFullMatrix(LinearOperatorExtras, # pylint: disable=...
method _matmul_right (line 69) | def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
method _matmul_sparse (line 73) | def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
method _matmul_right_sparse (line 76) | def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
class LinearOperatorDiag (line 81) | class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missin...
method _matmul_right (line 84) | def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
method _matmul_sparse (line 89) | def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
method _matmul_right_sparse (line 94) | def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
FILE: kfac/python/ops/loss_functions.py
class LossFunction (line 30) | class LossFunction(object):
method targets (line 42) | def targets(self):
method inputs (line 51) | def inputs(self):
method evaluate (line 55) | def evaluate(self):
method _evaluate (line 65) | def _evaluate(self, targets):
method multiply_ggn (line 77) | def multiply_ggn(self, vector):
method multiply_ggn_factor (line 94) | def multiply_ggn_factor(self, vector):
method multiply_ggn_factor_transpose (line 116) | def multiply_ggn_factor_transpose(self, vector):
method multiply_ggn_factor_replicated_one_hot (line 138) | def multiply_ggn_factor_replicated_one_hot(self, index):
method ggn_factor_inner_shape (line 165) | def ggn_factor_inner_shape(self):
method ggn_factor_inner_static_shape (line 170) | def ggn_factor_inner_static_shape(self):
method dtype (line 175) | def dtype(self):
class NegativeLogProbLoss (line 182) | class NegativeLogProbLoss(LossFunction):
method __init__ (line 185) | def __init__(self, seed=None):
method inputs (line 190) | def inputs(self):
method params (line 194) | def params(self):
method multiply_fisher (line 199) | def multiply_fisher(self, vector):
method multiply_fisher_factor (line 213) | def multiply_fisher_factor(self, vector):
method multiply_fisher_factor_transpose (line 237) | def multiply_fisher_factor_transpose(self, vector):
method multiply_fisher_factor_replicated_one_hot (line 261) | def multiply_fisher_factor_replicated_one_hot(self, index):
method fisher_factor_inner_shape (line 290) | def fisher_factor_inner_shape(self):
method fisher_factor_inner_static_shape (line 295) | def fisher_factor_inner_static_shape(self):
method sample (line 300) | def sample(self, seed):
method evaluate_on_sample (line 304) | def evaluate_on_sample(self, seed=None):
class NaturalParamsNegativeLogProbLoss (line 320) | class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):
method multiply_ggn (line 331) | def multiply_ggn(self, vector):
method multiply_ggn_factor (line 334) | def multiply_ggn_factor(self, vector):
method multiply_ggn_factor_transpose (line 337) | def multiply_ggn_factor_transpose(self, vector):
method multiply_ggn_factor_replicated_one_hot (line 340) | def multiply_ggn_factor_replicated_one_hot(self, index):
method ggn_factor_inner_shape (line 344) | def ggn_factor_inner_shape(self):
method ggn_factor_inner_static_shape (line 348) | def ggn_factor_inner_static_shape(self):
class DistributionNegativeLogProbLoss (line 352) | class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
method __init__ (line 355) | def __init__(self, seed=None):
method dist (line 359) | def dist(self):
method _evaluate (line 363) | def _evaluate(self, targets):
method sample (line 366) | def sample(self, seed):
class NormalMeanNegativeLogProbLoss (line 370) | class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
method __init__ (line 384) | def __init__(self, mean, var=0.5, targets=None, seed=None):
method targets (line 393) | def targets(self):
method dist (line 397) | def dist(self):
method params (line 401) | def params(self):
method multiply_fisher (line 404) | def multiply_fisher(self, vector):
method multiply_fisher_factor (line 407) | def multiply_fisher_factor(self, vector):
method multiply_fisher_factor_transpose (line 410) | def multiply_fisher_factor_transpose(self, vector):
method multiply_fisher_factor_replicated_one_hot (line 413) | def multiply_fisher_factor_replicated_one_hot(self, index):
method fisher_factor_inner_shape (line 422) | def fisher_factor_inner_shape(self):
method fisher_factor_inner_static_shape (line 426) | def fisher_factor_inner_static_shape(self):
class NormalMeanVarianceNegativeLogProbLoss (line 430) | class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbL...
method __init__ (line 448) | def __init__(self, mean, variance, targets=None, seed=None):
method targets (line 457) | def targets(self):
method dist (line 461) | def dist(self):
method params (line 466) | def params(self):
method _concat (line 469) | def _concat(self, mean, variance):
method _split (line 472) | def _split(self, params):
method _fisher_mean (line 476) | def _fisher_mean(self):
method _fisher_mean_factor (line 480) | def _fisher_mean_factor(self):
method _fisher_var (line 484) | def _fisher_var(self):
method _fisher_var_factor (line 488) | def _fisher_var_factor(self):
method multiply_fisher (line 491) | def multiply_fisher(self, vecs):
method multiply_fisher_factor (line 495) | def multiply_fisher_factor(self, vecs):
method multiply_fisher_factor_transpose (line 500) | def multiply_fisher_factor_transpose(self, vecs):
method multiply_fisher_factor_replicated_one_hot (line 505) | def multiply_fisher_factor_replicated_one_hot(self, index):
method fisher_factor_inner_shape (line 528) | def fisher_factor_inner_shape(self):
method fisher_factor_inner_static_shape (line 533) | def fisher_factor_inner_static_shape(self):
method multiply_ggn (line 537) | def multiply_ggn(self, vector):
method multiply_ggn_factor (line 540) | def multiply_ggn_factor(self, vector):
method multiply_ggn_factor_transpose (line 543) | def multiply_ggn_factor_transpose(self, vector):
method multiply_ggn_factor_replicated_one_hot (line 546) | def multiply_ggn_factor_replicated_one_hot(self, index):
method ggn_factor_inner_shape (line 550) | def ggn_factor_inner_shape(self):
method ggn_factor_inner_static_shape (line 554) | def ggn_factor_inner_static_shape(self):
class CategoricalLogitsNegativeLogProbLoss (line 558) | class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
method __init__ (line 576) | def __init__(self, logits, targets=None, seed=None):
method targets (line 591) | def targets(self):
method dist (line 595) | def dist(self):
method _probs (line 599) | def _probs(self):
method _sqrt_probs (line 603) | def _sqrt_probs(self):
method params (line 607) | def params(self):
method multiply_fisher (line 610) | def multiply_fisher(self, vector):
method multiply_fisher_factor (line 615) | def multiply_fisher_factor(self, vector):
method multiply_fisher_factor_transpose (line 621) | def multiply_fisher_factor_transpose(self, vector):
method multiply_fisher_factor_replicated_one_hot (line 627) | def multiply_fisher_factor_replicated_one_hot(self, index):
method fisher_factor_inner_shape (line 637) | def fisher_factor_inner_shape(self):
method fisher_factor_inner_static_shape (line 641) | def fisher_factor_inner_static_shape(self):
class MultiBernoulliNegativeLogProbLoss (line 645) | class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
method __init__ (line 660) | def __init__(self, logits, targets=None, seed=None):
method targets (line 666) | def targets(self):
method dist (line 670) | def dist(self):
method _probs (line 674) | def _probs(self):
method params (line 678) | def params(self):
method multiply_fisher (line 681) | def multiply_fisher(self, vector):
method multiply_fisher_factor (line 684) | def multiply_fisher_factor(self, vector):
method multiply_fisher_factor_transpose (line 687) | def multiply_fisher_factor_transpose(self, vector):
method multiply_fisher_factor_replicated_one_hot (line 690) | def multiply_fisher_factor_replicated_one_hot(self, index):
method fisher_factor_inner_shape (line 698) | def fisher_factor_inner_shape(self):
method fisher_factor_inner_static_shape (line 702) | def fisher_factor_inner_static_shape(self):
function insert_slice_in_zeros (line 706) | def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
class OnehotCategoricalLogitsNegativeLogProbLoss (line 741) | class OnehotCategoricalLogitsNegativeLogProbLoss(
method dist (line 750) | def dist(self):
FILE: kfac/python/ops/op_queue.py
class OpQueue (line 25) | class OpQueue(object):
method __init__ (line 34) | def __init__(self, ops, seed=None):
method ops (line 52) | def ops(self):
method next_op (line 56) | def next_op(self, sess):
FILE: kfac/python/ops/optimizer.py
function set_global_constants (line 41) | def set_global_constants(include_damping_in_qmodel_change=None):
class KfacOptimizer (line 49) | class KfacOptimizer(tf.train.GradientDescentOptimizer):
method __init__ (line 52) | def __init__(self,
method get_cov_vars (line 422) | def get_cov_vars(self):
method get_inv_vars (line 426) | def get_inv_vars(self):
method factors (line 431) | def factors(self):
method registered_variables (line 435) | def registered_variables(self):
method layers (line 439) | def layers(self):
method mat_type (line 443) | def mat_type(self):
method damping (line 447) | def damping(self):
method damping_adaptation_interval (line 454) | def damping_adaptation_interval(self):
method learning_rate (line 458) | def learning_rate(self):
method momentum (line 465) | def momentum(self):
method rho (line 472) | def rho(self):
method qmodel_change (line 476) | def qmodel_change(self):
method counter (line 480) | def counter(self):
method params_stats (line 484) | def params_stats(self):
method set_loss (line 487) | def set_loss(self, loss):
method _maybe_print_logging_info (line 493) | def _maybe_print_logging_info(self):
method make_vars_and_create_op_thunks (line 509) | def make_vars_and_create_op_thunks(self):
method create_ops_and_vars_thunks (line 521) | def create_ops_and_vars_thunks(self):
method check_var_list (line 548) | def check_var_list(self, var_list):
method _scale_loss (line 554) | def _scale_loss(loss_value):
method minimize (line 563) | def minimize(self,
method compute_gradients (line 598) | def compute_gradients(self,
method _is_damping_adaptation_time (line 638) | def _is_damping_adaptation_time(self):
method _is_just_after_damping_adaptation_time (line 647) | def _is_just_after_damping_adaptation_time(self):
method _maybe_update_prev_loss (line 653) | def _maybe_update_prev_loss(self):
method maybe_pre_update_adapt_damping (line 673) | def maybe_pre_update_adapt_damping(self):
method _maybe_post_update_adapt_damping (line 704) | def _maybe_post_update_adapt_damping(self):
method apply_gradients (line 716) | def apply_gradients(self, grads_and_vars, *args, **kwargs):
method _add_weight_decay (line 771) | def _add_weight_decay(self, grads_and_vars):
method _squared_fisher_norm (line 783) | def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars):
method _update_clip_coeff (line 805) | def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars):
method _clip_updates (line 837) | def _clip_updates(self, grads_and_vars, precon_grads_and_vars):
method _compute_prev_updates (line 859) | def _compute_prev_updates(self, variables):
method _compute_qmodel (line 887) | def _compute_qmodel(self,
method _sub_damping_out_qmodel_change_coeff (line 976) | def _sub_damping_out_qmodel_change_coeff(self):
method _compute_qmodel_hyperparams (line 979) | def _compute_qmodel_hyperparams(self, m, c, b, fixed_mu=None):
method _compute_approx_qmodel_change (line 1106) | def _compute_approx_qmodel_change(self, updates_and_vars, grads_and_va...
method _maybe_update_qmodel_change (line 1132) | def _maybe_update_qmodel_change(self, qmodel_change_thunk):
method _multiply_preconditioner (line 1155) | def _multiply_preconditioner(self, vecs_and_vars):
method _get_qmodel_quantities (line 1158) | def _get_qmodel_quantities(self, grads_and_vars):
method _compute_raw_update_steps (line 1175) | def _compute_raw_update_steps(self, grads_and_vars):
method _update_velocities (line 1290) | def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0):
method _get_current_loss (line 1317) | def _get_current_loss(self):
method _get_prev_loss (line 1323) | def _get_prev_loss(self):
method _update_damping (line 1326) | def _update_damping(self):
function _two_by_two_solve (line 1367) | def _two_by_two_solve(m, vec):
function _eval_quadratic_no_c (line 1389) | def _eval_quadratic_no_c(m, vec):
function _eval_quadratic (line 1393) | def _eval_quadratic(m, c, vec):
FILE: kfac/python/ops/placement.py
function _make_thunk_on_device (line 30) | def _make_thunk_on_device(func, device):
class RoundRobinPlacementMixin (line 37) | class RoundRobinPlacementMixin(object):
method __init__ (line 40) | def __init__(self, cov_devices=None, inv_devices=None, trans_devices=N...
method _place_and_compute_transformation_thunks (line 62) | def _place_and_compute_transformation_thunks(self, thunks, params_list):
method create_ops_and_vars_thunks (line 90) | def create_ops_and_vars_thunks(self, scope=None):
class ReplicaRoundRobinPlacementMixin (line 169) | class ReplicaRoundRobinPlacementMixin(object):
method __init__ (line 184) | def __init__(self, distribute_transformations=True, **kwargs):
method _place_and_compute_transformation_thunks (line 205) | def _place_and_compute_transformation_thunks(self, thunks, params_list):
method create_ops_and_vars_thunks (line 225) | def create_ops_and_vars_thunks(self, scope=None):
FILE: kfac/python/ops/tensormatch/graph_matcher.py
function _any (line 65) | def _any(itr):
function _all (line 73) | def _all(itr):
function is_seq (line 84) | def is_seq(obj):
function is_nonempty_seq (line 88) | def is_nonempty_seq(obj):
function is_empty_seq (line 92) | def is_empty_seq(obj):
function is_element_pattern (line 102) | def is_element_pattern(pat):
function element_name (line 106) | def element_name(pat):
function element_restrictions (line 110) | def element_restrictions(pat):
function is_choice_pattern (line 114) | def is_choice_pattern(pat):
function choice_patterns (line 118) | def choice_patterns(pat):
function is_list_pattern (line 122) | def is_list_pattern(pat):
function list_patterns (line 126) | def list_patterns(pat):
function is_not_pattern (line 130) | def is_not_pattern(pat):
function negated_pattern (line 134) | def negated_pattern(pat):
function is_any_pattern (line 138) | def is_any_pattern(pat):
function is_any_noconsume_pattern (line 142) | def is_any_noconsume_pattern(pat):
function is_internal_node_pattern (line 146) | def is_internal_node_pattern(pat):
function internal_node_pattern (line 154) | def internal_node_pattern(pat):
function internal_node_input_pattern (line 158) | def internal_node_input_pattern(pat):
function internal_node_output_pattern (line 165) | def internal_node_output_pattern(pat):
function internal_patterns (line 172) | def internal_patterns(pat):
function match_eqv (line 180) | def match_eqv(pattern):
function match_any (line 186) | def match_any(data, bindings, consumed, succeed):
function match_any_noconsume (line 194) | def match_any_noconsume(data, bindings, consumed, succeed): # pylint: d...
function match_element (line 201) | def match_element(variable_name, restrictions):
function match_choice (line 215) | def match_choice(*match_combinators):
function match_list (line 222) | def match_list(*match_combinators):
function match_not (line 253) | def match_not(match_combinator):
function match_internal (line 261) | def match_internal(*match_combinators):
class PatternEvaluator (line 275) | class PatternEvaluator(object):
method __init__ (line 278) | def __init__(self, default_operation=None):
method defhandler (line 282) | def defhandler(self, predicate, handler):
method __call__ (line 285) | def __call__(self, pat):
function expand_thunks (line 320) | def expand_thunks(pat):
function matcher (line 345) | def matcher(pattern):
function all_matcher (line 352) | def all_matcher(pattern):
function matcher_with_consumed (line 364) | def matcher_with_consumed(pattern):
FILE: kfac/python/ops/tensormatch/graph_patterns.py
function Op (line 29) | def Op(name=None):
function Tensor (line 33) | def Tensor(name=None):
function Variable (line 37) | def Variable(name=None):
function Const (line 41) | def Const(name=None):
function Placeholder (line 45) | def Placeholder(name=None):
function BatchNorm (line 62) | def BatchNorm(in_pattern=Tensor('in'),
function FusedBatchNormOutput (line 81) | def FusedBatchNormOutput(in_pattern=Tensor('in'),
function ScaleAndShift (line 97) | def ScaleAndShift(in_pattern=Tensor('in'),
function Affine (line 116) | def Affine(in_pattern=Tensor('in'),
function Embed (line 133) | def Embed(in_pattern=Tensor('in'),
function Layer (line 150) | def Layer(in_pattern=Tensor('in'), **kwargs):
function LayerWithBatchNorm (line 157) | def LayerWithBatchNorm(in_pattern=Tensor('in')):
FILE: kfac/python/ops/tensormatch/graph_search.py
class RecordType (line 30) | class RecordType(enum.Enum):
class AmbiguousRegistrationError (line 37) | class AmbiguousRegistrationError(Exception):
class MatchRecord (line 41) | class MatchRecord(object):
method __init__ (line 44) | def __init__(self, record_type, params, tensor_set, data=None):
function ensure_sequence (line 63) | def ensure_sequence(obj):
function record_affine_from_bindings (line 71) | def record_affine_from_bindings(bindings, consumed_tensors,
function record_scale_and_shift_from_bindings (line 161) | def record_scale_and_shift_from_bindings(bindings, consumed_tensors,
function record_batch_norm_from_bindings (line 204) | def record_batch_norm_from_bindings(bindings, consumed_tensors,
function register_layers (line 251) | def register_layers(layer_collection, varlist, batch_size=None):
function register_subgraph_layers (line 330) | def register_subgraph_layers(layer_collection,
function filter_user_registered_records (line 471) | def filter_user_registered_records(record_list_dict, user_registered_var...
function filter_grouped_variable_records (line 482) | def filter_grouped_variable_records(layer_collection, record_list_dict):
function filter_subgraph_records (line 496) | def filter_subgraph_records(record_list_dict):
function filter_records (line 530) | def filter_records(layer_collection, record_list_dict,
function register_records (line 580) | def register_records(layer_collection,
FILE: kfac/python/ops/tensormatch/tensorflow_graph_util.py
function is_op (line 35) | def is_op(node):
function is_tensor (line 39) | def is_tensor(node):
function is_var (line 47) | def is_var(node):
function is_const (line 64) | def is_const(node):
function is_placeholder (line 68) | def is_placeholder(node):
function is_leaf (line 72) | def is_leaf(node):
function is_identity (line 76) | def is_identity(node):
function op_type_is (line 85) | def op_type_is(typename):
function reduce_identity_ops (line 93) | def reduce_identity_ops(node):
function expand_inputs (line 106) | def expand_inputs(node):
function expand_outputs (line 120) | def expand_outputs(node):
function make_op_pattern (line 131) | def make_op_pattern(typename):
function import_ops_no_clobber (line 148) | def import_ops_no_clobber(dct, op_names):
FILE: kfac/python/ops/utils.py
function smart_assign (line 38) | def smart_assign(variable, value, assign_fn=tf.assign,
function smart_cond (line 91) | def smart_cond(predicate, true_fn, false_fn, name=None):
function set_global_constants (line 158) | def set_global_constants(posdef_inv_method=None, tf_replicator=None):
class SequenceDict (line 169) | class SequenceDict(object):
method __init__ (line 172) | def __init__(self, iterable=None):
method __getitem__ (line 175) | def __getitem__(self, key_or_keys):
method __setitem__ (line 181) | def __setitem__(self, key_or_keys, val_or_vals):
method items (line 188) | def items(self):
function tensors_to_column (line 192) | def tensors_to_column(tensors):
function column_to_tensors (line 208) | def column_to_tensors(tensors_template, colvec):
function kronecker_product (line 238) | def kronecker_product(mat1, mat2):
function layer_params_to_mat2d (line 247) | def layer_params_to_mat2d(vector):
function mat2d_to_layer_params (line 270) | def mat2d_to_layer_params(vector_template, mat2d):
function posdef_inv (line 294) | def posdef_inv(tensor, damping):
function posdef_inv_matrix_inverse (line 301) | def posdef_inv_matrix_inverse(tensor, identity, damping):
function posdef_inv_cholesky (line 306) | def posdef_inv_cholesky(tensor, identity, damping):
function posdef_inv_eig (line 312) | def posdef_inv_eig(tensor, identity, damping):
function posdef_eig (line 325) | def posdef_eig(mat):
function posdef_eig_svd (line 330) | def posdef_eig_svd(mat):
function posdef_eig_self_adjoint (line 337) | def posdef_eig_self_adjoint(mat):
function cholesky (line 351) | def cholesky(tensor, damping):
class SubGraph (line 358) | class SubGraph(object):
method __init__ (line 362) | def __init__(self, outputs):
method _iter_add (line 369) | def _iter_add(self, root):
method is_member (line 384) | def is_member(self, node):
method variable_uses (line 388) | def variable_uses(self, var):
method filter_list (line 431) | def filter_list(self, node_list):
function preferred_int_dtype (line 440) | def preferred_int_dtype():
function generate_random_signs (line 449) | def generate_random_signs(shape, dtype=tf.float32):
class MirroredVariableWrapper (line 463) | class MirroredVariableWrapper(object):
method __init__ (line 465) | def __init__(self, var):
method __getattr__ (line 468) | def __getattr__(self, name):
function _as_list (line 481) | def _as_list(x):
function fwd_gradients (line 485) | def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None,
function get_tf_replicator (line 518) | def get_tf_replicator():
function is_tpu_replicated (line 522) | def is_tpu_replicated():
function is_replicated (line 531) | def is_replicated():
function get_num_replicas (line 538) | def get_num_replicas():
function get_replica_id (line 561) | def get_replica_id():
function all_sum (line 588) | def all_sum(structure, name=None):
function all_average (line 624) | def all_average(structure, name=None):
function map_gather (line 652) | def map_gather(thunks, name=None):
function ensure_sequence (line 711) | def ensure_sequence(obj):
function batch_execute (line 719) | def batch_execute(global_step, thunks, batch_size, name=None):
function extract_convolution_patches (line 786) | def extract_convolution_patches(inputs,
function extract_pointwise_conv2d_patches (line 858) | def extract_pointwise_conv2d_patches(inputs,
function is_data_format_channel_last (line 904) | def is_data_format_channel_last(data_format):
function matmul_sparse_dense (line 911) | def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=...
function matmul_diag_sparse (line 943) | def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid...
class AccumulatorVariable (line 969) | class AccumulatorVariable(object):
method __init__ (line 977) | def __init__(self, name, shape, dtype):
method accumulate (line 1007) | def accumulate(self, value):
method value (line 1014) | def value(self):
method read_value_and_reset (line 1018) | def read_value_and_reset(self):
method reset (line 1025) | def reset(self):
class PartitionedTensor (line 1035) | class PartitionedTensor(object):
method __init__ (line 1038) | def __init__(self, tensors):
method shape (line 1073) | def shape(self):
method get_shape (line 1079) | def get_shape(self):
method dtype (line 1083) | def dtype(self):
method one_hot_depth (line 1087) | def one_hot_depth(self):
method __str__ (line 1090) | def __str__(self):
method __hash__ (line 1094) | def __hash__(self):
method __eq__ (line 1097) | def __eq__(self, other):
method __ne__ (line 1102) | def __ne__(self, other):
method __getitem__ (line 1105) | def __getitem__(self, key):
method as_tensor (line 1108) | def as_tensor(self, dtype=None, name=None, as_ref=False):
method device (line 1115) | def device(self):
function _check_match_lists_of_pairs (line 1133) | def _check_match_lists_of_pairs(list1, list2):
function sprod (line 1140) | def sprod(scalar, list_):
function sprod_p (line 1145) | def sprod_p(scalar, list_):
function sum_ (line 1150) | def sum_(list1, list2):
function sum_p (line 1155) | def sum_p(list1, list2):
function ip (line 1162) | def ip(list1, list2):
function ip_p (line 1168) | def ip_p(list1, list2):
function assert_variables_match_pairs_list (line 1176) | def assert_variables_match_pairs_list(a_and_vars,
function multiline_print (line 1208) | def multiline_print(lists):
function get_shape (line 1226) | def get_shape(tensor):
function cls_name (line 1238) | def cls_name(obj):
function is_reference_variable (line 1242) | def is_reference_variable(x):
class MovingAverageVariable (line 1248) | class MovingAverageVariable(object):
method __init__ (line 1257) | def __init__(self, name, shape, dtype, initializer=tf.zeros_initialize...
method dtype (line 1291) | def dtype(self):
method value (line 1295) | def value(self):
method add_to_average (line 1301) | def add_to_average(self, value, decay=1.0, weight=1.0):
method reset (line 1322) | def reset(self):
function num_conv_locations (line 1329) | def num_conv_locations(input_shape, filter_shape, strides, padding):
Condensed preview of 69 files, each showing path, character count, and a content snippet (1,052K chars in full).
[
{
"path": ".travis.yml",
"chars": 418,
"preview": "language: python\npython:\n - \"3.6\"\nenv:\n matrix:\n - TF_VERSION=\"1.15\"\ninstall:\n - pip install -q \"tensorflow==$TF_V"
},
{
"path": "AUTHORS",
"chars": 313,
"preview": "# This is the official list of TensorFlow authors for copyright purposes.\n# This file is distinct from the CONTRIBUTORS "
},
{
"path": "LICENSE",
"chars": 11416,
"preview": "Copyright 2019 The TensorFlow Authors. All rights reserved.\n\n Apache License\n "
},
{
"path": "README.md",
"chars": 1035,
"preview": "# K-FAC: Kronecker-Factored Approximate Curvature\n\n[](https:"
},
{
"path": "docs/applications.md",
"chars": 14,
"preview": "Coming Soon..\n"
},
{
"path": "docs/contact.md",
"chars": 1003,
"preview": "Topic | Contact\n-------------------- | ---------------------\n**Questions** | kfac-users@google.com"
},
{
"path": "docs/examples/auto_damp.md",
"chars": 2352,
"preview": "# Automatic tuning of damping parameter.\n\n## Table of Contents\n\n* [1. Cached Reader](#1-cached-reader)\n* [2. Build o"
},
{
"path": "docs/examples/convolutional.md",
"chars": 8026,
"preview": "# Convolutional\n\n## Table of Contents\n\n* [Build the Model](#build-the-model)\n* [Register the layers and loss](#regis"
},
{
"path": "docs/examples/distributed_training.md",
"chars": 6025,
"preview": "# Distributed Training\n\n## Table of Contents\n\n* [Register the layers](#register-the-layers)\n* [Build the optimizer]("
},
{
"path": "docs/examples/parameters.md",
"chars": 5173,
"preview": "# K-FAC Parameters.\n\n## Table of Contents\n\n* [Damping](#damping)\n* [Learning Rate](#learning-rate)\n* [Subsample co"
},
{
"path": "docs/index.md",
"chars": 4256,
"preview": "# Home\n\nKronecker factored approximate curvature\n\n**K-FAC in TensorFlow** is an implementation of K-FAC, an approximate\n"
},
{
"path": "docs/papers.md",
"chars": 919,
"preview": "* Martens, James, and Roger Grosse.\n [\"Optimizing neural networks with Kronecker-factored approximate curvature.\"]["
},
{
"path": "docs/sitemap.md",
"chars": 755,
"preview": "* [Home](https://github.com/tensorflow/kfac/tree/master/docs/index.md)\n* User Guide\n * [Convolutional](https://"
},
{
"path": "kfac/__init__.py",
"chars": 2237,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/examples/__init__.py",
"chars": 1,
"preview": "\n"
},
{
"path": "kfac/examples/autoencoder_mnist.py",
"chars": 24377,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/examples/autoencoder_mnist_tpu_estimator.py",
"chars": 9093,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/examples/autoencoder_mnist_tpu_strategy.py",
"chars": 8402,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/examples/classifier_mnist.py",
"chars": 21345,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/examples/classifier_mnist_tpu_estimator.py",
"chars": 8610,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/examples/convnet.py",
"chars": 33436,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/examples/keras/KFAC_vs_Adam_Experiment.md",
"chars": 3903,
"preview": "# KFAC vs Adam Experiment\n\n## Set Up\n\nWe compare KFAC and Adam on a RESNET-20 on the CIFAR10 dataset. We split CIFAR10\ni"
},
{
"path": "kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10.ipynb",
"chars": 21591,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"colab_type\": \"text\",\n \"id\": \"_DD"
},
{
"path": "kfac/examples/keras/KFAC_vs_Adam_on_CIFAR10_TPU.ipynb",
"chars": 32644,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"colab_type\": \"text\",\n \"id\": \"_DD"
},
{
"path": "kfac/examples/mnist.py",
"chars": 4291,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/examples/rnn_mnist.py",
"chars": 14213,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/__init__.py",
"chars": 1,
"preview": "\n"
},
{
"path": "kfac/python/keras/README.md",
"chars": 7114,
"preview": "# K-FAC for Keras\n\n**K-FAC for Keras** is an implementation of K-FAC, an approximate second-order\noptimization method, i"
},
{
"path": "kfac/python/keras/__init__.py",
"chars": 872,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/keras/callbacks.py",
"chars": 7598,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/keras/optimizers.py",
"chars": 19535,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/keras/saving_utils.py",
"chars": 6320,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/keras/utils.py",
"chars": 22393,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/data_reader_test.py",
"chars": 2617,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/estimator_test.py",
"chars": 11665,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/graph_search_test.py",
"chars": 34509,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/keras_callbacks_test.py",
"chars": 9163,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/keras_optimizers_test.py",
"chars": 41363,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/keras_saving_utils_test.py",
"chars": 10744,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/keras_utils_test.py",
"chars": 25212,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/layer_collection_test.py",
"chars": 22448,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/loss_functions_test.py",
"chars": 7205,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/op_queue_test.py",
"chars": 1670,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/optimizer_test.py",
"chars": 7517,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/periodic_inv_cov_update_kfac_opt_test.py",
"chars": 3178,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/kernel_tests/utils_test.py",
"chars": 16451,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/__init__.py",
"chars": 1,
"preview": "\n"
},
{
"path": "kfac/python/ops/curvature_matrix_vector_products.py",
"chars": 9985,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/estimator.py",
"chars": 29042,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/fisher_blocks.py",
"chars": 70400,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/fisher_factors.py",
"chars": 104730,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/kfac_utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "kfac/python/ops/kfac_utils/async_inv_cov_update_kfac_opt.py",
"chars": 5554,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/kfac_utils/data_reader.py",
"chars": 5162,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/kfac_utils/data_reader_alt.py",
"chars": 4750,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/kfac_utils/periodic_inv_cov_update_kfac_opt.py",
"chars": 9287,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/layer_collection.py",
"chars": 78800,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/linear_operator.py",
"chars": 3500,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/loss_functions.py",
"chars": 24697,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/op_queue.py",
"chars": 2367,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/optimizer.py",
"chars": 61064,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/placement.py",
"chars": 13848,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/tensormatch/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "kfac/python/ops/tensormatch/graph_matcher.py",
"chars": 11056,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/tensormatch/graph_patterns.py",
"chars": 6062,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/tensormatch/graph_search.py",
"chars": 34206,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/tensormatch/tensorflow_graph_util.py",
"chars": 4566,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "kfac/python/ops/utils.py",
"chars": 48544,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
},
{
"path": "setup.py",
"chars": 1961,
"preview": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"L"
}
]
About this extraction
This page contains the source code of the tensorflow/kfac GitHub repository, extracted and formatted as plain text. The extraction includes 69 files (989.3 KB), approximately 237.9k tokens, and a symbol index with 1212 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.