Full Code of tensorflow/lattice for AI

Repository: tensorflow/lattice
Branch: master
Commit: f52258331345
Files: 64
Total size: 1.2 MB

Directory structure:
gitextract_apwg_kre/

├── .gitmodules
├── AUTHORS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── WORKSPACE
├── docs/
│   ├── _book.yaml
│   ├── _index.yaml
│   ├── build_docs.py
│   ├── install.md
│   ├── overview.md
│   └── tutorials/
│       ├── aggregate_function_models.ipynb
│       ├── keras_layers.ipynb
│       ├── premade_models.ipynb
│       ├── shape_constraints.ipynb
│       └── shape_constraints_for_ethics.ipynb
├── examples/
│   ├── BUILD
│   ├── keras_functional_uci_heart.py
│   └── keras_sequential_uci_heart.py
├── setup.py
└── tensorflow_lattice/
    ├── BUILD
    ├── __init__.py
    ├── layers/
    │   └── __init__.py
    └── python/
        ├── BUILD
        ├── __init__.py
        ├── aggregation_layer.py
        ├── aggregation_test.py
        ├── categorical_calibration_layer.py
        ├── categorical_calibration_lib.py
        ├── categorical_calibration_test.py
        ├── cdf_layer.py
        ├── cdf_test.py
        ├── conditional_cdf.py
        ├── conditional_cdf_test.py
        ├── conditional_pwl_calibration.py
        ├── conditional_pwl_calibration_test.py
        ├── configs.py
        ├── configs_test.py
        ├── internal_utils.py
        ├── internal_utils_test.py
        ├── kronecker_factored_lattice_layer.py
        ├── kronecker_factored_lattice_lib.py
        ├── kronecker_factored_lattice_test.py
        ├── lattice_layer.py
        ├── lattice_lib.py
        ├── lattice_test.py
        ├── linear_layer.py
        ├── linear_lib.py
        ├── linear_test.py
        ├── model_info.py
        ├── parallel_combination_layer.py
        ├── parallel_combination_test.py
        ├── premade.py
        ├── premade_lib.py
        ├── premade_test.py
        ├── pwl_calibration_layer.py
        ├── pwl_calibration_lib.py
        ├── pwl_calibration_test.py
        ├── rtl_layer.py
        ├── rtl_lib.py
        ├── rtl_test.py
        ├── test_utils.py
        ├── utils.py
        └── utils_test.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitmodules
================================================
[submodule "tensorflow"]
	path = tensorflow
	url = https://github.com/tensorflow/tensorflow.git
	branch = r1.3


================================================
FILE: AUTHORS
================================================
# This is the official list of TensorFlow Lattice authors for copyright purposes.
# Names should be added to this file as:
# Name or Organization <email address>
# The email address is not required for organizations.
Google Inc.


================================================
FILE: CONTRIBUTING.md
================================================
<!-- Copyright 2017 The TensorFlow Lattice Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================================-->
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution,
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.


================================================
FILE: LICENSE
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<!-- Copyright 2020 The TensorFlow Lattice Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================================-->
# TensorFlow Lattice

TensorFlow Lattice is a library that implements constrained and interpretable
lattice based models. It is an implementation of
[Monotonic Calibrated Interpolated Look-Up Tables](http://jmlr.org/papers/v17/15-243.html)
in [TensorFlow](https://www.tensorflow.org).

The library enables you to inject domain knowledge into
the learning process through common-sense or policy-driven shape constraints.
This is done using a collection of Keras layers that can satisfy constraints
such as monotonicity, convexity and pairwise trust:

* PWLCalibration: piecewise linear calibration of signals.
* CategoricalCalibration: mapping of categorical inputs into real values.
* Lattice: interpolated look-up table implementation.
* Linear: linear function with monotonicity and norm constraints.

The library also provides easy-to-set-up canned estimators for common use cases:

* Calibrated Linear
* Calibrated Lattice
* Random Tiny Lattices (RTL)
* Crystals

With TF Lattice you can use domain knowledge to better extrapolate to the parts
of the input space not covered by the training dataset. This helps avoid
unexpected model behaviour when the serving distribution is different from the
training distribution.

<div align="center">
  <img src="docs/images/model_comparison.png">
</div>

You can install our prebuilt pip package using

```bash
pip install tensorflow-lattice
```


================================================
FILE: WORKSPACE
================================================
# Copyright 2018 The TensorFlow Lattice Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
# ==============================================================================

workspace(name = "tensorflow_lattice")


================================================
FILE: docs/_book.yaml
================================================
upper_tabs:
# Tabs left of dropdown menu
- include: /_upper_tabs_left.yaml
- include: /api_docs/_upper_tabs_api.yaml
# Dropdown menu
- name: Resources
  path: /resources
  is_default: true
  menu:
  - include: /resources/_menu_toc.yaml
  lower_tabs:
    # Subsite tabs
    other:
    - name: Guide & Tutorials
      contents:
      - title: Overview
        path: /lattice/overview
      - title: Install
        path: /lattice/install
      - heading: Tutorials
      - title: Shape Constraints
        path: /lattice/tutorials/shape_constraints
      - title: Ethical Constraints for ML Fairness
        path: /lattice/tutorials/shape_constraints_for_ethics
      - title: Keras Layers and Custom Models
        path: /lattice/tutorials/keras_layers
      - title: Keras Premade Models
        path: /lattice/tutorials/premade_models
      - title: Aggregate Function Models
        path: /lattice/tutorials/aggregate_function_models

    - name: API
      skip_translation: true
      contents:
      - title: All Symbols
        path: /lattice/api_docs/python/tfl/all_symbols
      - include: /lattice/api_docs/python/tfl/_toc.yaml

- include: /_upper_tabs_right.yaml


================================================
FILE: docs/_index.yaml
================================================
book_path: /lattice/_book.yaml
project_path: /lattice/_project.yaml
description: A library for training constrained and interpretable lattice based models. Inject
 domain knowledge into the learning process through constraints on Keras layers.
landing_page:
  custom_css_path: /site-assets/css/style.css
  rows:
  - heading: Flexible, controlled and interpretable ML with lattice based models
    items:
    - classname: devsite-landing-row-50
      description: >
        <p>TensorFlow Lattice is a library that implements constrained and interpretable lattice
        based models. The library enables you to inject domain knowledge into the learning process
        through common-sense or policy-driven
        <a href="./tutorials/shape_constraints">shape constraints</a>. This is done using a
        collection of <a href="./tutorials/keras_layers">Keras layers</a> that can satisfy
        constraints such as monotonicity, convexity and how features interact. The library also
        provides easy-to-set-up <a href="./tutorials/premade_models">premade models</a>.</p>
        <p>With TF Lattice you can use domain knowledge to better extrapolate to the parts of the
        input space not covered by the training dataset. This helps avoid unexpected model behaviour
        when the serving distribution is different from the training distribution.</p>
        <figure>
            <img src="images/model_comparison.png">
        </figure>

      code_block: |
        <pre class = "prettyprint">
        import numpy as np
        import tensorflow as tf
        import tensorflow_lattice as tfl

        model = tf.keras.models.Sequential()
        model.add(
            tfl.layers.ParallelCombination([
                # Monotonic piece-wise linear calibration with bounded output
                tfl.layers.PWLCalibration(
                    monotonicity='increasing',
                    input_keypoints=np.linspace(1., 5., num=20),
                    output_min=0.0,
                    output_max=1.0),
                # Diminishing returns
                tfl.layers.PWLCalibration(
                    monotonicity='increasing',
                    convexity='concave',
                    input_keypoints=np.linspace(0., 200., num=20),
                    output_min=0.0,
                    output_max=2.0),
                # Partially monotonic categorical calibration: calib(0) <= calib(1)
                tfl.layers.CategoricalCalibration(
                    num_buckets=4,
                    output_min=0.0,
                    output_max=1.0,
                    monotonicities=[(0, 1)]),
            ]))
        model.add(
            tfl.layers.Lattice(
                lattice_sizes=[2, 3, 2],
                monotonicities=['increasing', 'increasing', 'increasing'],
                # Trust: model is more responsive to input 0 if input 1 increases
                edgeworth_trusts=(0, 1, 'positive')))
        model.compile(...)
        </pre>

  - classname: devsite-landing-row-cards
    items:
    - heading: "TensorFlow Lattice: Flexible, controlled and interpretable ML"
      image_path: /resources/images/tf-logo-card-16x9.png
      path: https://blog.tensorflow.org/2020/02/tensorflow-lattice-flexible-controlled-and-interpretable-ML.html
      buttons:
      - label: "Read on the TensorFlow blog"
        path: https://blog.tensorflow.org/2020/02/tensorflow-lattice-flexible-controlled-and-interpretable-ML.html
    - heading: "TensorFlow Lattice: Control your ML with monotonicity"
      youtube_id: ABBnNjbjv2Q
      buttons:
      - label: Watch the video
        path: https://www.youtube.com/watch?v=ABBnNjbjv2Q
    - heading: "TF Lattice on GitHub"
      image_path: /resources/images/github-card-16x9.png
      path: https://github.com/tensorflow/lattice
      buttons:
      - label: "View on GitHub"
        path: https://github.com/tensorflow/lattice


================================================
FILE: docs/build_docs.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate docs API for TF Lattice.

Example run:

```
python build_docs.py --output_dir=/path/to/output
```
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

from absl import app
from absl import flags

from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api

import tensorflow_lattice as tfl

flags.DEFINE_string('output_dir', '/tmp/tfl_api/',
                    'The path to output the files to')

flags.DEFINE_string(
    'code_url_prefix',
    'https://github.com/tensorflow/lattice/blob/master/tensorflow_lattice',
    'The url prefix for links to code.')

flags.DEFINE_bool('search_hints', True,
                  'Include metadata search hints in the generated files')

flags.DEFINE_string('site_path', 'lattice/api_docs/python',
                    'Path prefix in the _toc.yaml')

FLAGS = flags.FLAGS


def local_definitions_filter(path, parent, children):
  """Filters local imports, except for the tfl.layers module."""
  if path == ('tfl', 'layers'):
    return children
  return public_api.local_definitions_filter(path, parent, children)


def main(_):
  private_map = {
      'tfl': ['python'],
      'tfl.aggregation_layer': ['Aggregation'],
      'tfl.categorical_calibration_layer': ['CategoricalCalibration'],
      'tfl.cdf_layer': ['CDF'],
      'tfl.kronecker_factored_lattice_layer': ['KroneckerFactoredLattice'],
      'tfl.lattice_layer': ['Lattice'],
      'tfl.linear_layer': ['Linear'],
      'tfl.pwl_calibration_layer': ['PWLCalibration'],
      'tfl.parallel_combination_layer': ['ParallelCombination'],
      'tfl.rtl_layer': ['RTL'],
  }
  doc_generator = generate_lib.DocGenerator(
      root_title='TensorFlow Lattice 2.0',
      py_modules=[('tfl', tfl)],
      base_dir=os.path.dirname(tfl.__file__),
      code_url_prefix=FLAGS.code_url_prefix,
      search_hints=FLAGS.search_hints,
      site_path=FLAGS.site_path,
      private_map=private_map,
      callbacks=[local_definitions_filter])

  sys.exit(doc_generator.build(output_dir=FLAGS.output_dir))


if __name__ == '__main__':
  app.run(main)


================================================
FILE: docs/install.md
================================================
# Install TensorFlow Lattice

There are several ways to set up your environment to use TensorFlow Lattice
(TFL).

*   The easiest way to learn and use TFL requires no installation: run any
    of the tutorials (e.g.
    [premade models](tutorials/premade_models.ipynb)).
*   To use TFL on a local machine, install the `tensorflow-lattice` pip package.
*   If you have a unique machine configuration, you can build the package from
    source.

## Install TensorFlow Lattice using pip

Install using pip.

```shell
pip install --upgrade tensorflow-lattice
```

Note that you will need to have the `tf_keras` package installed as well.

## Build from source

Clone the github repo:

```shell
git clone https://github.com/tensorflow/lattice.git
```

Build pip package from source:

```shell
python setup.py sdist bdist_wheel --universal --release
```

Install the package:

```shell
pip install --user --upgrade /path/to/pkg.whl
```


================================================
FILE: docs/overview.md
================================================
# TensorFlow Lattice (TFL)

TensorFlow Lattice is a library that implements flexible, controlled and
interpretable lattice based models. The library enables you to inject domain
knowledge into the learning process through common-sense or policy-driven
[shape constraints](tutorials/shape_constraints.ipynb). This is done using a
collection of [Keras layers](tutorials/keras_layers.ipynb) that can satisfy
constraints such as monotonicity, convexity and pairwise trust. The library also
provides easy-to-set-up [premade models](tutorials/premade_models.ipynb).

## Concepts

This section is a simplified version of the description in
[Monotonic Calibrated Interpolated Look-Up Tables](http://jmlr.org/papers/v17/15-243.html),
JMLR 2016.

### Lattices

A *lattice* is an interpolated look-up table that can approximate arbitrary
input-output relationships in your data. It overlaps a regular grid onto your
input space and learns values for the output in the vertices of the grid. For a
test point $x$, $f(x)$ is linearly interpolated from the lattice values
surrounding $x$.

<img src="images/2d_lattice.png" style="display:block; margin:auto;">

The simple example above is a function with 2 input features and 4 parameters:
$\theta=[0, 0.2, 0.4, 1]$, which are the function's values at the corners of the
input space; the rest of the function is interpolated from these parameters.

The function $f(x)$ can capture non-linear interactions between features. You
can think of the lattice parameters as the height of poles set in the ground on
a regular grid, and the resulting function is like cloth pulled tight against
the four poles.

With $D$ features and 2 vertices along each dimension, a regular lattice will
have $2^D$ parameters. To fit a more flexible function, you can specify a
finer-grained lattice over the feature space with more vertices along each
dimension. Lattice regression functions are continuous and piecewise infinitely
differentiable.
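
The interpolation described above can be sketched in a few lines of NumPy. This is an illustrative sketch, not the library's implementation, and the corner ordering assumed for $\theta=[0, 0.2, 0.4, 1]$ is a guess made for the example:

```python
import numpy as np

# 2x2 lattice over [0, 1]^2, with one parameter per corner.
# Corner ordering for theta = [0, 0.2, 0.4, 1] is assumed here:
# rows index x1 in {0, 1}, columns index x2 in {0, 1}.
theta = np.array([[0.0, 0.2],
                  [0.4, 1.0]])

def lattice_2d(x1, x2, theta):
    """Bilinearly interpolate the corner values at a point (x1, x2)."""
    return ((1 - x1) * (1 - x2) * theta[0, 0]
            + (1 - x1) * x2 * theta[0, 1]
            + x1 * (1 - x2) * theta[1, 0]
            + x1 * x2 * theta[1, 1])

print(lattice_2d(0.0, 0.0, theta))  # corner value: 0.0
print(lattice_2d(0.5, 0.5, theta))  # midpoint: average of the four corners
```

With finer grids or more dimensions the same idea applies, interpolating among the $2^D$ vertices surrounding the query point.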

### Calibration

Let's say the preceding sample lattice represents a learned *user happiness*
with a suggested local coffee shop calculated using features:

*   coffee price, in range 0 to 20 dollars
*   distance to the user, in range 0 to 30 kilometers

We want our model to learn user happiness with a local coffee shop suggestion.
TensorFlow Lattice models can use *piecewise linear functions* (with
`tfl.layers.PWLCalibration`) to calibrate and normalize the input features to
the range accepted by the lattice: 0.0 to 1.0 in the example lattice above. The
following shows examples of such calibration functions with 10 keypoints:

<p align="center">
<img src="images/pwl_calibration_distance.png">
<img src="images/pwl_calibration_price.png">
</p>

It is often a good idea to use the quantiles of the features as input keypoints.
TensorFlow Lattice [premade models](tutorials/premade_models.ipynb) can
automatically set the input keypoints to the feature quantiles.

For categorical features, TensorFlow Lattice provides categorical calibration
(with `tfl.layers.CategoricalCalibration`) with similar output bounding to feed
into a lattice.
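
As a rough sketch (not the library's implementation), a trained calibrator is just a piecewise-linear function through learned output values at fixed input keypoints, which `np.interp` can evaluate directly. The keypoint positions and the `learned_outputs` values below are hypothetical stand-ins for what `tfl.layers.PWLCalibration` would learn during training:

```python
import numpy as np

# 10 evenly spaced keypoints for the distance feature (0-30 km);
# quantile-based keypoints are often a better choice in practice.
keypoints = np.linspace(0.0, 30.0, num=10)
# Hypothetical learned outputs: happiness decays with distance.
learned_outputs = np.linspace(1.0, 0.0, num=10) ** 2

def calibrate(distance_km):
    """Map a raw distance into the [0, 1] range the lattice accepts."""
    return float(np.interp(distance_km, keypoints, learned_outputs))

print(calibrate(0.0))   # 1.0 at the closest shop
print(calibrate(30.0))  # 0.0 at the farthest
```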

### Ensembles

The number of parameters of a lattice layer increases exponentially with the
number of input features, hence not scaling well to very high dimensions. To
overcome this limitation, TensorFlow Lattice offers ensembles of lattices that
combine (average) several *tiny* lattices, which enables the number of model
parameters to grow linearly with the number of features.

The library provides two variations of these ensembles:

*   **Random Tiny Lattices** (RTL): Each submodel uses a random subset of
    features (with replacement).

*   **Crystals**: The Crystals algorithm first trains a *prefitting* model that
    estimates pairwise feature interactions. It then arranges the final ensemble
    such that features with more non-linear interactions are in the same
    lattices.
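
The scaling argument can be checked with quick back-of-the-envelope arithmetic (2 vertices per dimension; the ensemble shape below is arbitrary, chosen only to illustrate the contrast):

```python
# Parameter count of one full lattice over all features vs. an ensemble of
# tiny lattices, each over a small feature subset.
num_features = 20
full_lattice_params = 2 ** num_features  # exponential in num_features

lattices_in_ensemble = 10
features_per_lattice = 3
ensemble_params = lattices_in_ensemble * 2 ** features_per_lattice

print(full_lattice_params)  # 1048576
print(ensemble_params)      # 80
```

The ensemble's parameter count grows linearly with the number of lattices (and hence roughly linearly with the number of features covered), while the single lattice grows exponentially.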

## Why TensorFlow Lattice?

You can find a brief introduction to TensorFlow Lattice in this
[TF Blog post](https://blog.tensorflow.org/2020/02/tensorflow-lattice-flexible-controlled-and-interpretable-ML.html).

### Interpretability

Since the parameters of each layer are the outputs of that layer at specific
input points (calibrator keypoints and lattice vertices), each part of the
model is easy to analyze, understand, and debug.

### Accurate and Flexible Models

Using fine-grained lattices, you can get *arbitrarily complex* functions with a
single lattice layer. Using multiple layers of calibrators and lattices often
works well in practice and can match or outperform DNN models of similar size.

### Common-Sense Shape Constraints

Real-world training data may not sufficiently represent the run-time data.
Flexible ML solutions such as DNNs or forests often act unexpectedly and even
wildly in parts of the input space not covered by the training data. This
behaviour is especially problematic when policy or fairness constraints can be
violated.

<img src="images/model_comparison.png" style="display:block; margin:auto;">

Even though common forms of regularization can result in more sensible
extrapolation, standard regularizers cannot guarantee reasonable model behaviour
across the entire input space, especially with high-dimensional inputs.
Switching to simpler models with more controlled and predictable behaviour can
come at a severe cost to the model accuracy.

TF Lattice makes it possible to keep using flexible models, but provides several
options to inject domain knowledge into the learning process through
semantically meaningful common-sense or policy-driven
[shape constraints](tutorials/shape_constraints.ipynb):

*   **Monotonicity**: You can specify that the output should only
    increase/decrease with respect to an input. In our example, you may want to
    specify that increased distance to a coffee shop should only decrease the
    predicted user preference.

<p align="center">
<img src="images/linear_fit.png">
<img src="images/flexible_fit.png">
<img src="images/regularized_fit.png">
<img src="images/monotonic_fit.png">
</p>

*   **Convexity/Concavity**: You can specify that the function shape can be
    convex or concave. Mixed with monotonicity, this can force the function to
    represent diminishing returns with respect to a given feature.

*   **Unimodality**: You can specify that the function should have a unique peak
    or unique valley. This lets you represent functions that have a *sweet spot*
    with respect to a feature.

*   **Pairwise trust**: This constraint works on a pair of features and suggests
    that one input feature semantically reflects trust in another feature. For
    example, a higher number of reviews makes you more confident in the average
    star rating of a restaurant. The model will be more sensitive with respect
    to the star rating (i.e. will have a larger slope with respect to the
    rating) when the number of reviews is higher.
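
One common way to enforce such constraints during training is to project the parameters back onto the feasible set after each gradient update. As a toy illustration only (not TFL's actual projection algorithm), a running maximum projects PWL output keypoints onto the set of non-decreasing sequences:

```python
import numpy as np

def project_monotonic(output_keypoints):
    """Project PWL output keypoints onto non-decreasing sequences.

    A running maximum is one simple feasible projection; TFL's real
    projection is more sophisticated, so treat this as an illustration.
    """
    return np.maximum.accumulate(output_keypoints)

unconstrained = np.array([0.0, 0.4, 0.3, 0.9])  # dips at the third keypoint
print(project_monotonic(unconstrained))  # third keypoint raised to 0.4
```

After projection, every segment of the calibrator has non-negative slope, so the calibrated output can only increase with the input.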

### Controlled Flexibility with Regularizers

In addition to shape constraints, TensorFlow Lattice provides a number of
regularizers to control the flexibility and smoothness of the function for each
layer.

*   **Laplacian Regularizer**: Outputs of the lattice/calibration
    vertices/keypoints are regularized towards the values of their respective
    neighbors. This results in a *flatter* function.

*   **Hessian Regularizer**: This penalizes changes in the first derivative
    (slope) of the PWL calibration layer to make the function *more linear*.

*   **Wrinkle Regularizer**: This penalizes changes in the second derivative
    (curvature) of the PWL calibration layer to avoid sudden changes in the
    curvature. It makes the function smoother.

*   **Torsion Regularizer**: Outputs of the lattice are regularized to prevent
    torsion among the features. In other words, the model is regularized
    towards independence between the contributions of the features.
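
For the PWL calibration layer, these penalties can be understood as successive finite differences of the output keypoints. The following sketch is illustrative only; the actual TFL regularizers differ in weighting and configuration options.

```python
import numpy as np

def pwl_penalties(output_keypoints):
    """Finite-difference sketches of the three PWL calibration penalties."""
    d1 = np.diff(output_keypoints)  # Laplacian penalizes |d1| (flatter)
    d2 = np.diff(d1)                # Hessian penalizes |d2| (more linear)
    d3 = np.diff(d2)                # Wrinkle penalizes |d3| (smoother)
    return np.abs(d1).sum(), np.abs(d2).sum(), np.abs(d3).sum()

# A perfectly linear calibration incurs only the Laplacian penalty.
laplacian, hessian, wrinkle = pwl_penalties(np.array([0.0, 0.5, 1.0]))
print(laplacian, hessian, wrinkle)  # 1.0 0.0 0.0
```

A flat function would zero out all three penalties; a linear but sloped one zeroes only the Hessian and wrinkle terms, which is why the Laplacian term pushes toward *flatter* functions while the other two push toward *smoother* ones.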

### Mix and match with other Keras layers

You can use TF Lattice layers in combination with other Keras layers to
construct partially constrained or regularized models. For example, lattice or
PWL calibration layers can be used at the last layer of deeper networks that
include embeddings or other Keras layers.

## Papers

*   [Deontological Ethics By Monotonicity Shape Constraints](https://arxiv.org/abs/2001.11990),
    Serena Wang, Maya Gupta, International Conference on Artificial Intelligence
    and Statistics (AISTATS), 2020
*   [Shape Constraints for Set Functions](http://proceedings.mlr.press/v97/cotter19a.html),
    Andrew Cotter, Maya Gupta, H. Jiang, Erez Louidor, Jim Muller, Taman
    Narayan, Serena Wang, Tao Zhu. International Conference on Machine Learning
    (ICML), 2019
*   [Diminishing Returns Shape Constraints for Interpretability and
    Regularization](https://papers.nips.cc/paper/7916-diminishing-returns-shape-constraints-for-interpretability-and-regularization),
    Maya Gupta, Dara Bahri, Andrew Cotter, Kevin Canini, Advances in Neural
    Information Processing Systems (NeurIPS), 2018
*   [Deep Lattice Networks and Partial Monotonic Functions](https://research.google.com/pubs/pub46327.html),
    Seungil You, Kevin Canini, David Ding, Jan Pfeifer, Maya R. Gupta, Advances
    in Neural Information Processing Systems (NeurIPS), 2017
*   [Fast and Flexible Monotonic Functions with Ensembles of Lattices](https://papers.nips.cc/paper/6377-fast-and-flexible-monotonic-functions-with-ensembles-of-lattices),
    Mahdi Milani Fard, Kevin Canini, Andrew Cotter, Jan Pfeifer, Maya Gupta,
    Advances in Neural Information Processing Systems (NeurIPS), 2016
*   [Monotonic Calibrated Interpolated Look-Up Tables](http://jmlr.org/papers/v17/15-243.html),
    Maya Gupta, Andrew Cotter, Jan Pfeifer, Konstantin Voevodski, Kevin Canini,
    Alexander Mangylov, Wojciech Moczydlowski, Alexander van Esbroeck, Journal
    of Machine Learning Research (JMLR), 2016
*   [Optimized Regression for Efficient Function Evaluation](http://ieeexplore.ieee.org/document/6203580/),
    Eric Garcia, Raman Arora, Maya R. Gupta, IEEE Transactions on Image
    Processing, 2012
*   [Lattice Regression](https://papers.nips.cc/paper/3694-lattice-regression),
    Eric Garcia, Maya Gupta, Advances in Neural Information Processing Systems
    (NeurIPS), 2009

## Tutorials and API docs

For common model architectures, you can use
[Keras premade models](tutorials/premade_models.ipynb). You can also create
custom models using [TF Lattice Keras layers](tutorials/keras_layers.ipynb) or
mix and match with other Keras layers. Check out the
[full API docs](https://www.tensorflow.org/lattice/api_docs/python/tfl) for
details.


================================================
FILE: docs/tutorials/aggregate_function_models.ipynb
================================================
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RYmPh1qB_KO2"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oMRm3czy9tLh"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ooXoR4kx_YL9"
      },
      "source": [
        "# TF Lattice Aggregate Function Models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BR6XNYEXEgSU"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lattice/tutorials/aggregate_function_models\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/aggregate_function_models.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/aggregate_function_models.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/aggregate_function_models.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-ZfQWUmfEsyZ"
      },
      "source": [
        "## Overview\n",
        "\n",
        "TFL Premade Aggregate Function Models are quick and easy ways to build TFL `keras.Model` instances for learning complex aggregation functions. This guide outlines the steps needed to construct a TFL Premade Aggregate Function Model and train/test it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L0lgWoB6Gmk1"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Installing TF Lattice package:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ivwKrEdLGphZ"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install -U tensorflow tf-keras tensorflow-lattice pydot graphviz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VQsRKS4wGrMu"
      },
      "source": [
        "Importing required packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j41-kd4MGtDS"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import collections\n",
        "import logging\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tensorflow_lattice as tfl\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HlJH1SMx3Vul"
      },
      "outputs": [],
      "source": [
        "# Use Keras 2.\n",
        "version_fn = getattr(tf.keras, \"version\", None)\n",
        "if version_fn and version_fn().startswith(\"3.\"):\n",
        "  import tf_keras as keras\n",
        "else:\n",
        "  keras = tf.keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZHPohKjBIFG5"
      },
      "source": [
        "Downloading the Puzzles dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VjYHpw2dSfHH"
      },
      "outputs": [],
      "source": [
        "train_dataframe = pd.read_csv(\n",
        "    'https://raw.githubusercontent.com/wbakst/puzzles_data/master/train.csv')\n",
        "train_dataframe.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UOsgu3eIEur6"
      },
      "outputs": [],
      "source": [
        "test_dataframe = pd.read_csv(\n",
        "    'https://raw.githubusercontent.com/wbakst/puzzles_data/master/test.csv')\n",
        "test_dataframe.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XG7MPCyzVr22"
      },
      "source": [
        "Extracting and converting features and labels:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bYdJicq5bBuz"
      },
      "outputs": [],
      "source": [
        "# Features:\n",
        "# - star_rating       rating out of 5 stars (1-5)\n",
        "# - word_count        number of words in the review\n",
        "# - is_amazon         1 = reviewed on amazon; 0 = reviewed on artifact website\n",
        "# - includes_photo    if the review includes a photo of the puzzle\n",
        "# - num_helpful       number of people that found this review helpful\n",
        "# - num_reviews       total number of reviews for this puzzle (we construct)\n",
        "#\n",
        "# This ordering of feature names will be the exact same order that we construct\n",
        "# our model to expect.\n",
        "feature_names = [\n",
        "    'star_rating', 'word_count', 'is_amazon', 'includes_photo', 'num_helpful',\n",
        "    'num_reviews'\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kx0ZX2HR-4qb"
      },
      "outputs": [],
      "source": [
        "def extract_features(dataframe, label_name):\n",
        "  # First we extract flattened features.\n",
        "  flattened_features = {\n",
        "      feature_name: dataframe[feature_name].values.astype(float)\n",
        "      for feature_name in feature_names[:-1]\n",
        "  }\n",
        "\n",
        "  # Construct mapping from puzzle name to feature.\n",
        "  star_rating = collections.defaultdict(list)\n",
        "  word_count = collections.defaultdict(list)\n",
        "  is_amazon = collections.defaultdict(list)\n",
        "  includes_photo = collections.defaultdict(list)\n",
        "  num_helpful = collections.defaultdict(list)\n",
        "  labels = {}\n",
        "\n",
        "  # Extract each review.\n",
        "  for i in range(len(dataframe)):\n",
        "    row = dataframe.iloc[i]\n",
        "    puzzle_name = row['puzzle_name']\n",
        "    star_rating[puzzle_name].append(float(row['star_rating']))\n",
        "    word_count[puzzle_name].append(float(row['word_count']))\n",
        "    is_amazon[puzzle_name].append(float(row['is_amazon']))\n",
        "    includes_photo[puzzle_name].append(float(row['includes_photo']))\n",
        "    num_helpful[puzzle_name].append(float(row['num_helpful']))\n",
        "    labels[puzzle_name] = float(row[label_name])\n",
        "\n",
        "  # Organize data into list of list of features.\n",
        "  names = list(star_rating.keys())\n",
        "  star_rating = [star_rating[name] for name in names]\n",
        "  word_count = [word_count[name] for name in names]\n",
        "  is_amazon = [is_amazon[name] for name in names]\n",
        "  includes_photo = [includes_photo[name] for name in names]\n",
        "  num_helpful = [num_helpful[name] for name in names]\n",
        "  num_reviews = [[len(ratings)] * len(ratings) for ratings in star_rating]\n",
        "  labels = [labels[name] for name in names]\n",
        "\n",
        "  # Flatten num_reviews\n",
        "  flattened_features['num_reviews'] = [len(reviews) for reviews in num_reviews]\n",
        "\n",
        "  # Convert data into ragged tensors.\n",
        "  star_rating = tf.ragged.constant(star_rating)\n",
        "  word_count = tf.ragged.constant(word_count)\n",
        "  is_amazon = tf.ragged.constant(is_amazon)\n",
        "  includes_photo = tf.ragged.constant(includes_photo)\n",
        "  num_helpful = tf.ragged.constant(num_helpful)\n",
        "  num_reviews = tf.ragged.constant(num_reviews)\n",
        "  labels = tf.constant(labels)\n",
        "\n",
        "  # Now we can return our extracted data.\n",
        "  return (star_rating, word_count, is_amazon, includes_photo, num_helpful,\n",
        "          num_reviews), labels, flattened_features"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Nd6j_J5CbNiz"
      },
      "outputs": [],
      "source": [
        "train_xs, train_ys, flattened_features = extract_features(train_dataframe, 'Sales12-18MonthsAgo')\n",
        "test_xs, test_ys, _ = extract_features(test_dataframe, 'SalesLastSixMonths')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KfHHhCRsHejl"
      },
      "outputs": [],
      "source": [
        "# Let's define our label minimum and maximum.\n",
        "min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9TwqlRirIhAq"
      },
      "source": [
        "Setting the default values used for training in this guide:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GckmXFzRIhdD"
      },
      "outputs": [],
      "source": [
        "LEARNING_RATE = 0.1\n",
        "BATCH_SIZE = 128\n",
        "NUM_EPOCHS = 500\n",
        "MIDDLE_DIM = 3\n",
        "MIDDLE_LATTICE_SIZE = 2\n",
        "MIDDLE_KEYPOINTS = 16\n",
        "OUTPUT_KEYPOINTS = 8"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TpDKon4oIh2W"
      },
      "source": [
        "## Feature Configs\n",
        "\n",
        "Feature calibration and per-feature configurations are set using [tfl.configs.FeatureConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/FeatureConfig). Feature configurations include monotonicity constraints, per-feature regularization (see [tfl.configs.RegularizerConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/RegularizerConfig)), and lattice sizes for lattice models.\n",
        "\n",
        "Note that we must fully specify the feature config for any feature that we want our model to recognize. Otherwise the model will have no way of knowing that such a feature exists. For aggregation models, these features will automatically be considered and properly handled as ragged."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_IMwcDh7Xs5n"
      },
      "source": [
        "### Compute Quantiles\n",
        "\n",
        "Although the default setting for `pwl_calibration_input_keypoints` in `tfl.configs.FeatureConfig` is 'quantiles', for premade models we have to manually define the input keypoints. To do so, we first define our own helper function for computing quantiles."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l0uYl9ZpXtW1"
      },
      "outputs": [],
      "source": [
        "def compute_quantiles(features,\n",
        "                      num_keypoints=10,\n",
        "                      clip_min=None,\n",
        "                      clip_max=None,\n",
        "                      missing_value=None):\n",
        "  # Clip min and max if desired.\n",
        "  if clip_min is not None:\n",
        "    features = np.maximum(features, clip_min)\n",
        "    features = np.append(features, clip_min)\n",
        "  if clip_max is not None:\n",
        "    features = np.minimum(features, clip_max)\n",
        "    features = np.append(features, clip_max)\n",
        "  # Make features unique.\n",
        "  unique_features = np.unique(features)\n",
        "  # Remove missing values if specified.\n",
        "  if missing_value is not None:\n",
        "    unique_features = np.delete(unique_features,\n",
        "                                np.where(unique_features == missing_value))\n",
        "  # Compute and return quantiles over unique non-missing feature values.\n",
        "  return np.quantile(\n",
        "      unique_features,\n",
        "      np.linspace(0., 1., num=num_keypoints),\n",
        "      interpolation='nearest').astype(float)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9oYZdVeWEhf2"
      },
      "source": [
        "### Defining Our Feature Configs\n",
        "\n",
        "Now that we can compute our quantiles, we define a feature config for each feature that we want our model to take as input."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rEYlSXhTEmoh"
      },
      "outputs": [],
      "source": [
        "# Feature configs are used to specify how each feature is calibrated and used.\n",
        "feature_configs = [\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='star_rating',\n",
        "        lattice_size=2,\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            flattened_features['star_rating'], num_keypoints=5),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='word_count',\n",
        "        lattice_size=2,\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            flattened_features['word_count'], num_keypoints=5),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='is_amazon',\n",
        "        lattice_size=2,\n",
        "        num_buckets=2,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='includes_photo',\n",
        "        lattice_size=2,\n",
        "        num_buckets=2,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='num_helpful',\n",
        "        lattice_size=2,\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            flattened_features['num_helpful'], num_keypoints=5),\n",
        "        # Larger num_helpful indicating more trust in star_rating.\n",
        "        reflects_trust_in=[\n",
        "            tfl.configs.TrustConfig(\n",
        "                feature_name=\"star_rating\", trust_type=\"trapezoid\"),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='num_reviews',\n",
        "        lattice_size=2,\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            flattened_features['num_reviews'], num_keypoints=5),\n",
        "    )\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9zoPJRBvPdcH"
      },
      "source": [
        "## Aggregate Function Model\n",
        "\n",
        "To construct a TFL premade model, first construct a model configuration from [tfl.configs](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs). An aggregate function model is constructed using the [tfl.configs.AggregateFunctionConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/AggregateFunctionConfig). It applies piecewise-linear and categorical calibration, followed by a lattice model on each dimension of the ragged input. It then applies an aggregation layer over the output for each dimension. This is then followed by an optional output piecewise-linear calibration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l_4J7EjSPiP3"
      },
      "outputs": [],
      "source": [
        "# Model config defines the model structure for the aggregate function model.\n",
        "aggregate_function_model_config = tfl.configs.AggregateFunctionConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    middle_dimension=MIDDLE_DIM,\n",
        "    middle_lattice_size=MIDDLE_LATTICE_SIZE,\n",
        "    middle_calibration=True,\n",
        "    middle_calibration_num_keypoints=MIDDLE_KEYPOINTS,\n",
        "    middle_monotonicity='increasing',\n",
        "    output_min=min_label,\n",
        "    output_max=max_label,\n",
        "    output_calibration=True,\n",
        "    output_calibration_num_keypoints=OUTPUT_KEYPOINTS,\n",
        "    output_initialization=np.linspace(\n",
        "        min_label, max_label, num=OUTPUT_KEYPOINTS))\n",
        "# An AggregateFunction premade model constructed from the given model config.\n",
        "aggregate_function_model = tfl.premade.AggregateFunction(\n",
        "    aggregate_function_model_config)\n",
        "# Let's plot our model.\n",
        "keras.utils.plot_model(\n",
        "    aggregate_function_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4F7AwiXgWhe2"
      },
      "source": [
        "The output of each Aggregation layer is the averaged output of a calibrated lattice over the ragged inputs. Here is the model used inside the first Aggregation layer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UM7XF6UIWo4T"
      },
      "outputs": [],
      "source": [
        "aggregation_layers = [\n",
        "    layer for layer in aggregate_function_model.layers\n",
        "    if isinstance(layer, tfl.layers.Aggregation)\n",
        "]\n",
        "keras.utils.plot_model(\n",
        "    aggregation_layers[0].model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0ohYOftgTZhq"
      },
      "source": [
        "Now, as with any other [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model), we compile and fit the model to our data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uB9di3-lTfMy"
      },
      "outputs": [],
      "source": [
        "aggregate_function_model.compile(\n",
        "    loss='mae',\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE))\n",
        "aggregate_function_model.fit(\n",
        "    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pwZtGDR-Tzur"
      },
      "source": [
        "After training our model, we can evaluate it on our test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RWj1YfubT0NE"
      },
      "outputs": [],
      "source": [
        "print('Test Set Evaluation...')\n",
        "print(aggregate_function_model.evaluate(test_xs, test_ys))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "aggregate_function_models.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1ohMV9lhzSWZq3aH27fBAZ1Oj3wy19PI0",
          "timestamp": 1588637142053
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}


================================================
FILE: docs/tutorials/keras_layers.ipynb
================================================
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7765UFHoyGx6"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZS8z-_KeywY9"
      },
      "source": [
        "# Creating Keras Models with TFL Layers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r61fkA2i9Y3_"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lattice/tutorials/keras_layers\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/keras_layers.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/keras_layers.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/keras_layers.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ecLbJCvJSSCd"
      },
      "source": [
        "##Overview\n",
        "\n",
        "You can use TFL Keras layers to construct Keras models with monotonicity and other shape constraints. This example builds and trains a calibrated lattice model for the UCI heart dataset using TFL layers.\n",
        "\n",
        "In a calibrated lattice model, each feature is transformed by a `tfl.layers.PWLCalibration` or a `tfl.layers.CategoricalCalibration` layer and the results are nonlinearly fused using a `tfl.layers.Lattice`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x769lI12IZXB"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fbBVAR6UeRN5"
      },
      "source": [
        "Installing TF Lattice package:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bpXjJKpSd3j4"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install -U tensorflow tf-keras tensorflow-lattice pydot graphviz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jSVl9SHTeSGX"
      },
      "source": [
        "Importing required packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "pm0LD8iyIZXF"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import logging\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tensorflow_lattice as tfl\n",
        "from tensorflow import feature_column as fc\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m8TsvLIe4Az-"
      },
      "outputs": [],
      "source": [
        "# Use Keras 2.\n",
        "version_fn = getattr(tf.keras, \"version\", None)\n",
        "if version_fn and version_fn().startswith(\"3.\"):\n",
        "  import tf_keras as keras\n",
        "else:\n",
        "  keras = tf.keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "svPuM6QNxlrH"
      },
      "source": [
        "Downloading the UCI Statlog (Heart) dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "PG3pFtK-IZXM"
      },
      "outputs": [],
      "source": [
        "# UCI Statlog (Heart) dataset.\n",
        "csv_file = keras.utils.get_file(\n",
        "    'heart.csv', 'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv')\n",
        "training_data_df = pd.read_csv(csv_file).sample(\n",
        "    frac=1.0, random_state=41).reset_index(drop=True)\n",
        "training_data_df.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nKkAw12SxvGG"
      },
      "source": [
        "Setting the default values used for training in this guide:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "krAJBE-yIZXR"
      },
      "outputs": [],
      "source": [
        "LEARNING_RATE = 0.1\n",
        "BATCH_SIZE = 128\n",
        "NUM_EPOCHS = 100"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0TGfzhPHzpix"
      },
      "source": [
        "## Sequential Keras Model\n",
        "\n",
        "This example creates a Sequential Keras model and only uses TFL layers.\n",
        "\n",
        "Lattice layers expect `input[i]` to be within `[0, lattice_sizes[i] - 1.0]`, so we need to define the lattice sizes ahead of the calibration layers in order to properly specify the output range of the calibration layers.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nOQWqPAbQS3o"
      },
      "outputs": [],
      "source": [
        "# Lattice layer expects input[i] to be within [0, lattice_sizes[i] - 1.0].\n",
        "lattice_sizes = [3, 2, 2, 2, 2, 2, 2]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W3DnEKWvQYXm"
      },
      "source": [
        "We use a `tfl.layers.ParallelCombination` layer to group together the calibration layers that have to run in parallel, which allows us to create a Sequential model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o_hyk5GkQfl8"
      },
      "outputs": [],
      "source": [
        "combined_calibrators = tfl.layers.ParallelCombination()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BPZsSUZiQiwc"
      },
      "source": [
        "We create a calibration layer for each feature and add it to the parallel combination layer. For numeric features we use `tfl.layers.PWLCalibration`, and for categorical features we use `tfl.layers.CategoricalCalibration`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DXPc6rSGxzFZ"
      },
      "outputs": [],
      "source": [
        "# ############### age ###############\n",
        "calibrator = tfl.layers.PWLCalibration(\n",
        "    # Every PWLCalibration layer must have its piecewise linear function\n",
        "    # keypoints specified. The easiest way to specify them is to uniformly\n",
        "    # cover the entire input range using numpy.linspace().\n",
        "    input_keypoints=np.linspace(\n",
        "        training_data_df['age'].min(), training_data_df['age'].max(), num=5),\n",
        "    # Input keypoints must have the same dtype as the layer input. You can\n",
        "    # ensure this by setting dtype here, or by providing keypoints in a\n",
        "    # format that is converted to the desired tf.dtype by default.\n",
        "    dtype=tf.float32,\n",
        "    # Output range must correspond to expected lattice input range.\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[0] - 1.0,\n",
        ")\n",
        "combined_calibrators.append(calibrator)\n",
        "\n",
        "# ############### sex ###############\n",
        "# For boolean features, simply use a CategoricalCalibration layer with 2\n",
        "# buckets.\n",
        "calibrator = tfl.layers.CategoricalCalibration(\n",
        "    num_buckets=2,\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[1] - 1.0,\n",
        "    # Initializes all outputs to (output_min + output_max) / 2.0.\n",
        "    kernel_initializer='constant')\n",
        "combined_calibrators.append(calibrator)\n",
        "\n",
        "# ############### cp ###############\n",
        "calibrator = tfl.layers.PWLCalibration(\n",
        "    # Here instead of specifying dtype of layer we convert keypoints into\n",
        "    # np.float32.\n",
        "    input_keypoints=np.linspace(1, 4, num=4, dtype=np.float32),\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[2] - 1.0,\n",
        "    monotonicity='increasing',\n",
        "    # You can specify TFL regularizers as a tuple ('regularizer name', l1, l2).\n",
        "    kernel_regularizer=('hessian', 0.0, 1e-4))\n",
        "combined_calibrators.append(calibrator)\n",
        "\n",
        "# ############### trestbps ###############\n",
        "calibrator = tfl.layers.PWLCalibration(\n",
        "    # Alternatively, you might want to use quantiles as keypoints instead\n",
        "    # of uniform keypoints.\n",
        "    input_keypoints=np.quantile(training_data_df['trestbps'],\n",
        "                                np.linspace(0.0, 1.0, num=5)),\n",
        "    dtype=tf.float32,\n",
        "    # Together with quantile keypoints, you might want to initialize the\n",
        "    # piecewise linear function with 'equal_slopes' so that the layer output\n",
        "    # after initialization preserves the original distribution.\n",
        "    kernel_initializer='equal_slopes',\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[3] - 1.0,\n",
        "    # You might consider clamping extreme inputs of the calibrator to output\n",
        "    # bounds.\n",
        "    clamp_min=True,\n",
        "    clamp_max=True,\n",
        "    monotonicity='increasing')\n",
        "combined_calibrators.append(calibrator)\n",
        "\n",
        "# ############### chol ###############\n",
        "calibrator = tfl.layers.PWLCalibration(\n",
        "    # Explicit input keypoint initialization.\n",
        "    input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\n",
        "    dtype=tf.float32,\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[4] - 1.0,\n",
        "    # Calibrator monotonicity can be decreasing. Note that the\n",
        "    # corresponding lattice dimension must have INCREASING monotonicity\n",
        "    # regardless of the calibrator's direction.\n",
        "    monotonicity='decreasing',\n",
        "    # Convexity together with decreasing monotonicity results in a\n",
        "    # diminishing returns constraint.\n",
        "    convexity='convex',\n",
        "    # You can specify a list of regularizers. You are not limited to TFL\n",
        "    # regularizers; feel free to use any.\n",
        "    kernel_regularizer=[('laplacian', 0.0, 1e-4),\n",
        "                        keras.regularizers.l1_l2(l1=0.001)])\n",
        "combined_calibrators.append(calibrator)\n",
        "\n",
        "# ############### fbs ###############\n",
        "calibrator = tfl.layers.CategoricalCalibration(\n",
        "    num_buckets=2,\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[5] - 1.0,\n",
        "    # For categorical calibration layers, monotonicity is specified for\n",
        "    # pairs of category indices. The output for the first category in each\n",
        "    # pair will be smaller than the output for the second.\n",
        "    #\n",
        "    # Don't forget to set the monotonicity of the corresponding Lattice\n",
        "    # dimension to 'increasing'.\n",
        "    monotonicities=[(0, 1)],\n",
        "    # This initializer is identical to the default one ('uniform'), but has\n",
        "    # a fixed seed in order to simplify experimentation.\n",
        "    kernel_initializer=keras.initializers.RandomUniform(\n",
        "        minval=0.0, maxval=lattice_sizes[5] - 1.0, seed=1))\n",
        "combined_calibrators.append(calibrator)\n",
        "\n",
        "# ############### restecg ###############\n",
        "calibrator = tfl.layers.CategoricalCalibration(\n",
        "    num_buckets=3,\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[6] - 1.0,\n",
        "    # Categorical monotonicity can be a partial order.\n",
        "    monotonicities=[(0, 1), (0, 2)],\n",
        "    # Categorical calibration layer supports standard Keras regularizers.\n",
        "    kernel_regularizer=keras.regularizers.l1_l2(l1=0.001),\n",
        "    kernel_initializer='constant')\n",
        "combined_calibrators.append(calibrator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "inyNlSBeQyp7"
      },
      "source": [
        "We then create a lattice layer to nonlinearly fuse the outputs of the calibrators.\n",
        "\n",
        "Note that we need to specify increasing monotonicity for the required lattice dimensions. Composing this with the monotonicity direction of the calibration results in the correct end-to-end direction of monotonicity. This includes the partial monotonicity of the `tfl.layers.CategoricalCalibration` layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DNCc9oBTRo6w"
      },
      "outputs": [],
      "source": [
        "lattice = tfl.layers.Lattice(\n",
        "    lattice_sizes=lattice_sizes,\n",
        "    monotonicities=[\n",
        "        'increasing', 'none', 'increasing', 'increasing', 'increasing',\n",
        "        'increasing', 'increasing'\n",
        "    ],\n",
        "    output_min=0.0,\n",
        "    output_max=1.0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T5q2InayRpDr"
      },
      "source": [
        "We can then create a sequential model using the combined calibrators and lattice layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xX6lroYZQy3L"
      },
      "outputs": [],
      "source": [
        "model = keras.models.Sequential()\n",
        "model.add(combined_calibrators)\n",
        "model.add(lattice)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W3UFxD3fRzIC"
      },
      "source": [
        "Training works the same as any other Keras model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2jz4JvI-RzSj"
      },
      "outputs": [],
      "source": [
        "features = training_data_df[[\n",
        "    'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg'\n",
        "]].values.astype(np.float32)\n",
        "target = training_data_df[['target']].values.astype(np.float32)\n",
        "\n",
        "model.compile(\n",
        "    loss=keras.losses.mean_squared_error,\n",
        "    optimizer=keras.optimizers.Adagrad(learning_rate=LEARNING_RATE))\n",
        "model.fit(\n",
        "    features,\n",
        "    target,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    epochs=NUM_EPOCHS,\n",
        "    validation_split=0.2,\n",
        "    shuffle=False,\n",
        "    verbose=0)\n",
        "\n",
        "model.evaluate(features, target)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RTHoW_5lxwT5"
      },
      "source": [
        "## Functional Keras Model\n",
        "\n",
        "This example uses a functional API for Keras model construction.\n",
        "\n",
        "As mentioned in the previous section, lattice layers expect `input[i]` to be within `[0, lattice_sizes[i] - 1.0]`, so we need to define the lattice sizes ahead of the calibration layers in order to properly specify the output range of the calibration layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gJjUYvBuW1qE"
      },
      "outputs": [],
      "source": [
        "# We are going to have a 2-d embedding as one of the lattice inputs.\n",
        "lattice_sizes = [3, 2, 2, 3, 3, 2, 2]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z03qY5MYW1yT"
      },
      "source": [
        "For each feature, we need to create an input layer followed by a calibration layer. For numeric features we use `tfl.layers.PWLCalibration` and for categorical features we use `tfl.layers.CategoricalCalibration`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DCIUz8apzs0l"
      },
      "outputs": [],
      "source": [
        "model_inputs = []\n",
        "lattice_inputs = []\n",
        "# ############### age ###############\n",
        "age_input = keras.layers.Input(shape=[1], name='age')\n",
        "model_inputs.append(age_input)\n",
        "age_calibrator = tfl.layers.PWLCalibration(\n",
        "    # Every PWLCalibration layer must have its piecewise linear function\n",
        "    # keypoints specified. The easiest way to specify them is to uniformly\n",
        "    # cover the entire input range using numpy.linspace().\n",
        "    input_keypoints=np.linspace(\n",
        "        training_data_df['age'].min(), training_data_df['age'].max(), num=5),\n",
        "    # Input keypoints must have the same dtype as the layer input. You can\n",
        "    # ensure this by setting dtype here, or by providing keypoints in a\n",
        "    # format that is converted to the desired tf.dtype by default.\n",
        "    dtype=tf.float32,\n",
        "    # Output range must correspond to expected lattice input range.\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[0] - 1.0,\n",
        "    monotonicity='increasing',\n",
        "    name='age_calib',\n",
        ")(\n",
        "    age_input)\n",
        "lattice_inputs.append(age_calibrator)\n",
        "\n",
        "# ############### sex ###############\n",
        "# For boolean features, simply use a CategoricalCalibration layer with 2\n",
        "# buckets.\n",
        "sex_input = keras.layers.Input(shape=[1], name='sex')\n",
        "model_inputs.append(sex_input)\n",
        "sex_calibrator = tfl.layers.CategoricalCalibration(\n",
        "    num_buckets=2,\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[1] - 1.0,\n",
        "    # Initializes all outputs to (output_min + output_max) / 2.0.\n",
        "    kernel_initializer='constant',\n",
        "    name='sex_calib',\n",
        ")(\n",
        "    sex_input)\n",
        "lattice_inputs.append(sex_calibrator)\n",
        "\n",
        "# ############### cp ###############\n",
        "cp_input = keras.layers.Input(shape=[1], name='cp')\n",
        "model_inputs.append(cp_input)\n",
        "cp_calibrator = tfl.layers.PWLCalibration(\n",
        "    # Here instead of specifying dtype of layer we convert keypoints into\n",
        "    # np.float32.\n",
        "    input_keypoints=np.linspace(1, 4, num=4, dtype=np.float32),\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[2] - 1.0,\n",
        "    monotonicity='increasing',\n",
        "    # You can specify TFL regularizers as a tuple ('regularizer name', l1, l2).\n",
        "    kernel_regularizer=('hessian', 0.0, 1e-4),\n",
        "    name='cp_calib',\n",
        ")(\n",
        "    cp_input)\n",
        "lattice_inputs.append(cp_calibrator)\n",
        "\n",
        "# ############### trestbps ###############\n",
        "trestbps_input = keras.layers.Input(shape=[1], name='trestbps')\n",
        "model_inputs.append(trestbps_input)\n",
        "trestbps_calibrator = tfl.layers.PWLCalibration(\n",
        "    # Alternatively, you might want to use quantiles as keypoints instead\n",
        "    # of uniform keypoints.\n",
        "    input_keypoints=np.quantile(training_data_df['trestbps'],\n",
        "                                np.linspace(0.0, 1.0, num=5)),\n",
        "    dtype=tf.float32,\n",
        "    # Together with quantile keypoints, you might want to initialize the\n",
        "    # piecewise linear function with 'equal_slopes' so that the layer output\n",
        "    # after initialization preserves the original distribution.\n",
        "    kernel_initializer='equal_slopes',\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[3] - 1.0,\n",
        "    # You might consider clamping extreme inputs of the calibrator to output\n",
        "    # bounds.\n",
        "    clamp_min=True,\n",
        "    clamp_max=True,\n",
        "    monotonicity='increasing',\n",
        "    name='trestbps_calib',\n",
        ")(\n",
        "    trestbps_input)\n",
        "lattice_inputs.append(trestbps_calibrator)\n",
        "\n",
        "# ############### chol ###############\n",
        "chol_input = keras.layers.Input(shape=[1], name='chol')\n",
        "model_inputs.append(chol_input)\n",
        "chol_calibrator = tfl.layers.PWLCalibration(\n",
        "    # Explicit input keypoint initialization.\n",
        "    input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[4] - 1.0,\n",
        "    # Calibrator monotonicity can be decreasing. Note that the\n",
        "    # corresponding lattice dimension must have INCREASING monotonicity\n",
        "    # regardless of the calibrator's direction.\n",
        "    monotonicity='decreasing',\n",
        "    # Convexity together with decreasing monotonicity results in a\n",
        "    # diminishing returns constraint.\n",
        "    convexity='convex',\n",
        "    # You can specify a list of regularizers. You are not limited to TFL\n",
        "    # regularizers; feel free to use any.\n",
        "    kernel_regularizer=[('laplacian', 0.0, 1e-4),\n",
        "                        keras.regularizers.l1_l2(l1=0.001)],\n",
        "    name='chol_calib',\n",
        ")(\n",
        "    chol_input)\n",
        "lattice_inputs.append(chol_calibrator)\n",
        "\n",
        "# ############### fbs ###############\n",
        "fbs_input = keras.layers.Input(shape=[1], name='fbs')\n",
        "model_inputs.append(fbs_input)\n",
        "fbs_calibrator = tfl.layers.CategoricalCalibration(\n",
        "    num_buckets=2,\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[5] - 1.0,\n",
        "    # For categorical calibration layers, monotonicity is specified for\n",
        "    # pairs of category indices. The output for the first category in each\n",
        "    # pair will be smaller than the output for the second.\n",
        "    #\n",
        "    # Don't forget to set the monotonicity of the corresponding Lattice\n",
        "    # dimension to 'increasing'.\n",
        "    monotonicities=[(0, 1)],\n",
        "    # This initializer is identical to the default one ('uniform'), but has\n",
        "    # a fixed seed in order to simplify experimentation.\n",
        "    kernel_initializer=keras.initializers.RandomUniform(\n",
        "        minval=0.0, maxval=lattice_sizes[5] - 1.0, seed=1),\n",
        "    name='fbs_calib',\n",
        ")(\n",
        "    fbs_input)\n",
        "lattice_inputs.append(fbs_calibrator)\n",
        "\n",
        "# ############### restecg ###############\n",
        "restecg_input = keras.layers.Input(shape=[1], name='restecg')\n",
        "model_inputs.append(restecg_input)\n",
        "restecg_calibrator = tfl.layers.CategoricalCalibration(\n",
        "    num_buckets=3,\n",
        "    output_min=0.0,\n",
        "    output_max=lattice_sizes[6] - 1.0,\n",
        "    # Categorical monotonicity can be a partial order.\n",
        "    monotonicities=[(0, 1), (0, 2)],\n",
        "    # Categorical calibration layer supports standard Keras regularizers.\n",
        "    kernel_regularizer=keras.regularizers.l1_l2(l1=0.001),\n",
        "    kernel_initializer='constant',\n",
        "    name='restecg_calib',\n",
        ")(\n",
        "    restecg_input)\n",
        "lattice_inputs.append(restecg_calibrator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fr0k8La_YgQG"
      },
      "source": [
        "We then create a lattice layer to nonlinearly fuse the outputs of the calibrators.\n",
        "\n",
        "Note that we need to specify increasing monotonicity for the required lattice dimensions. Composing this with the monotonicity direction of the calibration results in the correct end-to-end direction of monotonicity. This includes the partial monotonicity of the `tfl.layers.CategoricalCalibration` layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X15RE0NybNbU"
      },
      "outputs": [],
      "source": [
        "lattice = tfl.layers.Lattice(\n",
        "    lattice_sizes=lattice_sizes,\n",
        "    monotonicities=[\n",
        "        'increasing', 'none', 'increasing', 'increasing', 'increasing',\n",
        "        'increasing', 'increasing'\n",
        "    ],\n",
        "    output_min=0.0,\n",
        "    output_max=1.0,\n",
        "    name='lattice',\n",
        ")(\n",
        "    lattice_inputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "31VzsnMCA9dh"
      },
      "source": [
        "To add more flexibility to the model, we add an output calibration layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "efCP3Yx2A9n7"
      },
      "outputs": [],
      "source": [
        "model_output = tfl.layers.PWLCalibration(\n",
        "    input_keypoints=np.linspace(0.0, 1.0, 5),\n",
        "    name='output_calib',\n",
        ")(\n",
        "    lattice)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1SURnNl8bNgw"
      },
      "source": [
        "We can now create a model using the inputs and outputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7gY-VXuYbZLa"
      },
      "outputs": [],
      "source": [
        "model = keras.models.Model(\n",
        "    inputs=model_inputs,\n",
        "    outputs=model_output)\n",
        "keras.utils.plot_model(model, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tvFJTs94bZXK"
      },
      "source": [
        "Training works the same as any other Keras model. Note that, with this setup, input features are passed as separate tensors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vMQTGbFAYgYS"
      },
      "outputs": [],
      "source": [
        "feature_names = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg']\n",
        "features = np.split(\n",
        "    training_data_df[feature_names].values.astype(np.float32),\n",
        "    indices_or_sections=len(feature_names),\n",
        "    axis=1)\n",
        "target = training_data_df[['target']].values.astype(np.float32)\n",
        "\n",
        "model.compile(\n",
        "    loss=keras.losses.mean_squared_error,\n",
        "    optimizer=keras.optimizers.Adagrad(LEARNING_RATE))\n",
        "model.fit(\n",
        "    features,\n",
        "    target,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    epochs=NUM_EPOCHS,\n",
        "    validation_split=0.2,\n",
        "    shuffle=False,\n",
        "    verbose=0)\n",
        "\n",
        "model.evaluate(features, target)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "keras_layers.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}


================================================
FILE: docs/tutorials/premade_models.ipynb
================================================
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HZiF5lbumA7j"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eNj0_BTFk479"
      },
      "source": [
        "# TF Lattice Premade Models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T3qE8F5toE28"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lattice/tutorials/premade_models\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/premade_models.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/premade_models.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/premade_models.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HEuRMAUOlFZa"
      },
      "source": [
        "## Overview\n",
        "\n",
        "Premade Models are quick and easy ways to build TFL `keras.Model` instances for typical use cases. This guide outlines the steps needed to construct a TFL Premade Model and train/test it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f2--Yq21lhRe"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Installing TF Lattice package:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XizqBCyXky4y"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install -U tensorflow tf-keras tensorflow-lattice pydot graphviz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2oKJPy5tloOB"
      },
      "source": [
        "Importing required packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9wZWJJggk4al"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import copy\n",
        "import logging\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tensorflow_lattice as tfl\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k-AAoRho3x5N"
      },
      "outputs": [],
      "source": [
        "# Use Keras 2.\n",
        "version_fn = getattr(tf.keras, \"version\", None)\n",
        "if version_fn and version_fn().startswith(\"3.\"):\n",
        "  import tf_keras as keras\n",
        "else:\n",
        "  keras = tf.keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oyOrtol7mW9r"
      },
      "source": [
        "Setting the default values used for training in this guide:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ns8pH2AnmgAC"
      },
      "outputs": [],
      "source": [
        "LEARNING_RATE = 0.01\n",
        "BATCH_SIZE = 128\n",
        "NUM_EPOCHS = 500\n",
        "PREFITTING_NUM_EPOCHS = 10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kpJJSS7YmLbG"
      },
      "source": [
        "Downloading the UCI Statlog (Heart) dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AYTcybljmQJm"
      },
      "outputs": [],
      "source": [
        "heart_csv_file = keras.utils.get_file(\n",
        "    'heart.csv',\n",
        "    'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv')\n",
        "heart_df = pd.read_csv(heart_csv_file)\n",
        "thal_vocab_list = ['normal', 'fixed', 'reversible']\n",
        "heart_df['thal'] = heart_df['thal'].map(\n",
        "    {v: i for i, v in enumerate(thal_vocab_list)})\n",
        "heart_df = heart_df.astype(float)\n",
        "\n",
        "heart_train_size = int(len(heart_df) * 0.8)\n",
        "heart_train_dict = dict(heart_df[:heart_train_size])\n",
        "heart_test_dict = dict(heart_df[heart_train_size:])\n",
        "\n",
        "# This ordering of input features should match the feature configs. If no\n",
        "# feature config relies explicitly on the data (i.e. all are 'quantiles'),\n",
        "# then you can construct the feature_names list by simply iterating over each\n",
        "# feature config and extracting its name.\n",
        "feature_names = [\n",
        "    'age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',\n",
        "    'exang', 'oldpeak', 'slope', 'ca', 'thal'\n",
        "]\n",
        "\n",
        "# Since some feature configs construct their input keypoints manually from\n",
        "# the training data, we need a mapping from feature name to index.\n",
        "feature_name_indices = {name: index for index, name in enumerate(feature_names)}\n",
        "\n",
        "label_name = 'target'\n",
        "heart_train_xs = [\n",
        "    heart_train_dict[feature_name] for feature_name in feature_names\n",
        "]\n",
        "heart_test_xs = [heart_test_dict[feature_name] for feature_name in feature_names]\n",
        "heart_train_ys = heart_train_dict[label_name]\n",
        "heart_test_ys = heart_test_dict[label_name]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ix2elMrGmiWX"
      },
      "source": [
        "## Feature Configs\n",
        "\n",
        "Feature calibration and per-feature configurations are set using [tfl.configs.FeatureConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/FeatureConfig). Feature configurations include monotonicity constraints, per-feature regularization (see [tfl.configs.RegularizerConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/RegularizerConfig)), and lattice sizes for lattice models.\n",
        "\n",
        "Note that we must fully specify the feature config for any feature that we want our model to recognize. Otherwise the model will have no way of knowing that such a feature exists."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ePWXuDH7-1i1"
      },
      "source": [
        "### Defining Our Feature Configs\n",
        "\n",
        "Now we define a feature config for each feature that we want our model to take as input. Keypoints specified as 'quantiles' will be computed from the training data in a later step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8y27RmHIrSBn"
      },
      "outputs": [],
      "source": [
        "# Features:\n",
        "# - age\n",
        "# - sex\n",
        "# - cp        chest pain type (4 values)\n",
        "# - trestbps  resting blood pressure\n",
        "# - chol      serum cholesterol in mg/dl\n",
        "# - fbs       fasting blood sugar \u003e 120 mg/dl\n",
        "# - restecg   resting electrocardiographic results (values 0,1,2)\n",
        "# - thalach   maximum heart rate achieved\n",
        "# - exang     exercise induced angina\n",
        "# - oldpeak   ST depression induced by exercise relative to rest\n",
        "# - slope     the slope of the peak exercise ST segment\n",
        "# - ca        number of major vessels (0-3) colored by fluoroscopy\n",
        "# - thal      normal; fixed defect; reversible defect\n",
        "#\n",
        "# Feature configs are used to specify how each feature is calibrated and used.\n",
        "heart_feature_configs = [\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='age',\n",
        "        lattice_size=3,\n",
        "        monotonicity='increasing',\n",
        "        # Input keypoints are computed from training-data quantiles.\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints='quantiles',\n",
        "        pwl_calibration_clip_max=100,\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='sex',\n",
        "        num_buckets=2,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='cp',\n",
        "        monotonicity='increasing',\n",
        "        # Keypoints that are uniformly spaced.\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        pwl_calibration_input_keypoints=np.linspace(\n",
        "            np.min(heart_train_xs[feature_name_indices['cp']]),\n",
        "            np.max(heart_train_xs[feature_name_indices['cp']]),\n",
        "            num=4),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='chol',\n",
        "        monotonicity='increasing',\n",
        "        # Explicit input keypoints initialization.\n",
        "        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\n",
        "        # Calibration can be forced to span the full output range by clamping.\n",
        "        pwl_calibration_clamp_min=True,\n",
        "        pwl_calibration_clamp_max=True,\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='fbs',\n",
        "        # Partial monotonicity: output(0) \u003c= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "        num_buckets=2,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='trestbps',\n",
        "        monotonicity='decreasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints='quantiles',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thalach',\n",
        "        monotonicity='decreasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints='quantiles',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='restecg',\n",
        "        # Partial monotonicity: output(0) \u003c= output(1), output(0) \u003c= output(2)\n",
        "        monotonicity=[(0, 1), (0, 2)],\n",
        "        num_buckets=3,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='exang',\n",
        "        # Partial monotonicity: output(0) \u003c= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "        num_buckets=2,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='oldpeak',\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints='quantiles',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='slope',\n",
        "        # Partial monotonicity: output(0) \u003c= output(1), output(1) \u003c= output(2)\n",
        "        monotonicity=[(0, 1), (1, 2)],\n",
        "        num_buckets=3,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='ca',\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        pwl_calibration_input_keypoints='quantiles',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thal',\n",
        "        # Partial monotonicity:\n",
        "        # output(normal) \u003c= output(fixed)\n",
        "        # output(normal) \u003c= output(reversible)\n",
        "        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],\n",
        "        num_buckets=3,\n",
        "        # We must specify the vocabulary list in order to later set the\n",
        "        # monotonicities since we used names and not indices.\n",
        "        vocabulary_list=thal_vocab_list,\n",
        "    ),\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-XuAnP_-vyK6"
      },
      "source": [
        "## Set Monotonicities and Keypoints\n",
        "\n",
        "Next we need to properly set the monotonicities for features where we used a custom vocabulary (such as 'thal' above)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZIn2-EfGv--m"
      },
      "outputs": [],
      "source": [
        "tfl.premade_lib.set_categorical_monotonicities(heart_feature_configs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fHyzh5YHyD5n"
      },
      "source": [
        "Finally we can complete our feature configs by calculating and setting the keypoints."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KJ5kKd-lyJhZ"
      },
      "outputs": [],
      "source": [
        "feature_keypoints = tfl.premade_lib.compute_feature_keypoints(\n",
        "    feature_configs=heart_feature_configs, features=heart_train_dict)\n",
        "tfl.premade_lib.set_feature_keypoints(\n",
        "    feature_configs=heart_feature_configs,\n",
        "    feature_keypoints=feature_keypoints,\n",
        "    add_missing_feature_configs=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mx50YgWMcxC4"
      },
      "source": [
        "## Calibrated Linear Model\n",
        "\n",
        "To construct a TFL premade model, first construct a model configuration from [tfl.configs](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs). A calibrated linear model is constructed using the [tfl.configs.CalibratedLinearConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/CalibratedLinearConfig). It applies piecewise-linear and categorical calibration on the input features, followed by a linear combination and an optional output piecewise-linear calibration. When using output calibration or when output bounds are specified, the linear layer will apply weighted averaging on calibrated inputs.\n",
        "\n",
        "This example creates a calibrated linear model on the first 5 features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UvMDJKqTc1vC"
      },
      "outputs": [],
      "source": [
        "# Model config defines the model structure for the premade model.\n",
        "linear_model_config = tfl.configs.CalibratedLinearConfig(\n",
        "    feature_configs=heart_feature_configs[:5],\n",
        "    use_bias=True,\n",
        "    output_calibration=True,\n",
        "    output_calibration_num_keypoints=10,\n",
        "    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\n",
        "    output_initialization=np.linspace(-2.0, 2.0, num=10),\n",
        "    regularizer_configs=[\n",
        "        # Regularizer for the output calibrator.\n",
        "        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\n",
        "    ])\n",
        "# A CalibratedLinear premade model constructed from the given model config.\n",
        "linear_model = tfl.premade.CalibratedLinear(linear_model_config)\n",
        "# Let's plot our model.\n",
        "keras.utils.plot_model(linear_model, show_layer_names=False, rankdir='LR')"
      ]
    },
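    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of the averaging behavior mentioned above (illustrative notation, not the exact library implementation): with per-feature calibrators $c_i$ and non-negative weights $w_i$ normalized so that $\\sum_i w_i = 1$, the bounded linear layer computes a weighted average of the calibrated inputs\n",
        "\n",
        "$$f(x) = \\sum_i w_i \\, c_i(x_i),$$\n",
        "\n",
        "which keeps the output within the range spanned by the calibrator outputs."
      ]
    },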
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3MC3-AyX00-A"
      },
      "source": [
        "Now, as with any other [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model), we compile and fit the model to our data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hPlEK-yG1B-U"
      },
      "outputs": [],
      "source": [
        "linear_model.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True)],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE))\n",
        "linear_model.fit(\n",
        "    heart_train_xs[:5],\n",
        "    heart_train_ys,\n",
        "    epochs=NUM_EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    verbose=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OG2ua0MGAkoi"
      },
      "source": [
        "After training our model, we can evaluate it on our test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HybGTvXxAoxV"
      },
      "outputs": [],
      "source": [
        "print('Test Set Evaluation...')\n",
        "print(linear_model.evaluate(heart_test_xs[:5], heart_test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jAAJK-wlc15S"
      },
      "source": [
        "## Calibrated Lattice Model\n",
        "\n",
        "A calibrated lattice model is constructed using [tfl.configs.CalibratedLatticeConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/CalibratedLatticeConfig). A calibrated lattice model applies piecewise-linear and categorical calibration on the input features, followed by a lattice model and an optional output piecewise-linear calibration.\n",
        "\n",
        "This example creates a calibrated lattice model on the first 5 features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u7gNcrMtc4Lp"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice model: inputs are calibrated, then combined\n",
        "# non-linearly using a lattice layer.\n",
        "lattice_model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=heart_feature_configs[:5],\n",
        "    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\n",
        "    output_initialization=[-2.0, 2.0],\n",
        "    regularizer_configs=[\n",
        "        # Torsion regularizer applied to the lattice to make it more linear.\n",
        "        tfl.configs.RegularizerConfig(name='torsion', l2=1e-2),\n",
        "        # Globally defined calibration regularizer is applied to all features.\n",
        "        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-2),\n",
        "    ])\n",
        "# A CalibratedLattice premade model constructed from the given model config.\n",
        "lattice_model = tfl.premade.CalibratedLattice(lattice_model_config)\n",
        "# Let's plot our model.\n",
        "keras.utils.plot_model(lattice_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nmc3TUIIGGoH"
      },
      "source": [
        "As before, we compile, fit, and evaluate our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vIjOQGD2Gp_Z"
      },
      "outputs": [],
      "source": [
        "lattice_model.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True)],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE))\n",
        "lattice_model.fit(\n",
        "    heart_train_xs[:5],\n",
        "    heart_train_ys,\n",
        "    epochs=NUM_EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(lattice_model.evaluate(heart_test_xs[:5], heart_test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bx74CD4Cc4T3"
      },
      "source": [
        "## Calibrated Lattice Ensemble Model\n",
        "\n",
        "When the number of features is large, you can use an ensemble model, which creates multiple smaller lattices for subsets of the features and averages their output instead of creating just a single huge lattice. Ensemble lattice models are constructed using [tfl.configs.CalibratedLatticeEnsembleConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/CalibratedLatticeEnsembleConfig). A calibrated lattice ensemble model applies piecewise-linear and categorical calibration on the input features, followed by an ensemble of lattice models and an optional output piecewise-linear calibration."
      ]
    },
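    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Schematically (illustrative notation, not taken from the library), with $K$ lattices $\\ell_k$, each reading a subset $S_k$ of the calibrated features $c(x)$, the ensemble output is the average\n",
        "\n",
        "$$f(x) = \\frac{1}{K} \\sum_{k=1}^{K} \\ell_k\\big(c(x)_{S_k}\\big),$$\n",
        "\n",
        "optionally followed by output piecewise-linear calibration."
      ]
    },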
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mbg4lsKqnEkV"
      },
      "source": [
        "### Explicit Lattice Ensemble Initialization\n",
        "\n",
        "If you already know which subsets of features you want to feed into your lattices, then you can explicitly set the lattices using feature names. This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yu8Twg8mdJ18"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combined non-linearly and averaged using multiple lattice layers.\n",
        "explicit_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=heart_feature_configs,\n",
        "    lattices=[['trestbps', 'chol', 'ca'], ['fbs', 'restecg', 'thal'],\n",
        "              ['fbs', 'cp', 'oldpeak'], ['exang', 'slope', 'thalach'],\n",
        "              ['restecg', 'age', 'sex']],\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\n",
        "    output_initialization=[-2.0, 2.0])\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config.\n",
        "explicit_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    explicit_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "keras.utils.plot_model(\n",
        "    explicit_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PJYR0i6MMDyh"
      },
      "source": [
        "As before, we compile, fit, and evaluate our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "capt98IOMHEm"
      },
      "outputs": [],
      "source": [
        "explicit_ensemble_model.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True)],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE))\n",
        "explicit_ensemble_model.fit(\n",
        "    heart_train_xs,\n",
        "    heart_train_ys,\n",
        "    epochs=NUM_EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(explicit_ensemble_model.evaluate(heart_test_xs, heart_test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VnI70C9gdKQw"
      },
      "source": [
        "### Random Lattice Ensemble\n",
        "\n",
        "If you are not sure which subsets of features to feed into your lattices, another option is to use random subsets of features for each lattice. This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7EhWrQaPIXj8"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combined non-linearly and averaged using multiple lattice layers.\n",
        "random_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=heart_feature_configs,\n",
        "    lattices='random',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\n",
        "    output_initialization=[-2.0, 2.0],\n",
        "    random_seed=42)\n",
        "# Now we must set the random lattice structure and construct the model.\n",
        "tfl.premade_lib.set_random_lattice_ensemble(random_ensemble_model_config)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config.\n",
        "random_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    random_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "keras.utils.plot_model(\n",
        "    random_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sbxcIF0PJUDc"
      },
      "source": [
        "As before, we compile, fit, and evaluate our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w0YdCDyGJY1G"
      },
      "outputs": [],
      "source": [
        "random_ensemble_model.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True)],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE))\n",
        "random_ensemble_model.fit(\n",
        "    heart_train_xs,\n",
        "    heart_train_ys,\n",
        "    epochs=NUM_EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(random_ensemble_model.evaluate(heart_test_xs, heart_test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZhJWe7fZIs4-"
      },
      "source": [
        "### RTL Layer Random Lattice Ensemble\n",
        "\n",
        "When using a random lattice ensemble, you can have the model use a single `tfl.layers.RTL` layer. Note that `tfl.layers.RTL` supports only monotonicity constraints, requires the same lattice size for all features, and allows no per-feature regularization. In exchange, a `tfl.layers.RTL` layer lets you scale to much larger ensembles than separate `tfl.layers.Lattice` instances.\n",
        "\n",
        "This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0PC9oRFYJMF_"
      },
      "outputs": [],
      "source": [
        "# Make sure our feature configs have the same lattice size, no per-feature\n",
        "# regularization, and only monotonicity constraints.\n",
        "rtl_layer_feature_configs = copy.deepcopy(heart_feature_configs)\n",
        "for feature_config in rtl_layer_feature_configs:\n",
        "  feature_config.lattice_size = 2\n",
        "  feature_config.unimodality = 'none'\n",
        "  feature_config.reflects_trust_in = None\n",
        "  feature_config.dominates = None\n",
        "  feature_config.regularizer_configs = None\n",
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combined non-linearly and averaged using multiple lattice layers.\n",
        "rtl_layer_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=rtl_layer_feature_configs,\n",
        "    lattices='rtl_layer',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\n",
        "    output_initialization=[-2.0, 2.0],\n",
        "    random_seed=42)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config. Note that we do not have to specify the lattices by calling\n",
        "# a helper function (like before with random) because the RTL Layer will take\n",
        "# care of that for us.\n",
        "rtl_layer_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    rtl_layer_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "keras.utils.plot_model(\n",
        "    rtl_layer_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yWdxZpS0JWag"
      },
      "source": [
        "As before, we compile, fit, and evaluate our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HQdkkWwqJW8p"
      },
      "outputs": [],
      "source": [
        "rtl_layer_ensemble_model.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True)],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE))\n",
        "rtl_layer_ensemble_model.fit(\n",
        "    heart_train_xs,\n",
        "    heart_train_ys,\n",
        "    epochs=NUM_EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(rtl_layer_ensemble_model.evaluate(heart_test_xs, heart_test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A61VpAl8uOiT"
      },
      "source": [
        "### Crystals Lattice Ensemble\n",
        "\n",
        "Premade also provides a heuristic feature arrangement algorithm, called [Crystals](https://papers.nips.cc/paper/6377-fast-and-flexible-monotonic-functions-with-ensembles-of-lattices). To use the Crystals algorithm, first we train a prefitting model that estimates pairwise feature interactions. We then arrange the final ensemble such that features with more non-linear interactions are in the same lattices.\n",
        "\n",
        "The premade library offers helper functions for constructing the prefitting model configuration and extracting the crystals structure. Note that the prefitting model does not need to be fully trained, so a few epochs should be enough.\n",
        "\n",
        "This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yT5eiknCu9sj"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combined non-linearly and averaged using multiple lattice layers.\n",
        "crystals_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=heart_feature_configs,\n",
        "    lattices='crystals',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\n",
        "    output_initialization=[-2.0, 2.0],\n",
        "    random_seed=42)\n",
        "# Now that we have our model config, we can construct a prefitting model config.\n",
        "prefitting_model_config = tfl.premade_lib.construct_prefitting_model_config(\n",
        "    crystals_ensemble_model_config)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# prefitting model config.\n",
        "prefitting_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    prefitting_model_config)\n",
        "# We can compile and train our prefitting model as we like.\n",
        "prefitting_model.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE))\n",
        "prefitting_model.fit(\n",
        "    heart_train_xs,\n",
        "    heart_train_ys,\n",
        "    epochs=PREFITTING_NUM_EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    verbose=False)\n",
        "# Now that we have our trained prefitting model, we can extract the crystals.\n",
        "tfl.premade_lib.set_crystals_lattice_ensemble(crystals_ensemble_model_config,\n",
        "                                              prefitting_model_config,\n",
        "                                              prefitting_model)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config.\n",
        "crystals_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    crystals_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "keras.utils.plot_model(\n",
        "    crystals_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PRLU1z-216h8"
      },
      "source": [
        "As before, we compile, fit, and evaluate our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U73On3v91-Qq"
      },
      "outputs": [],
      "source": [
        "crystals_ensemble_model.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True)],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE))\n",
        "crystals_ensemble_model.fit(\n",
        "    heart_train_xs,\n",
        "    heart_train_ys,\n",
        "    epochs=NUM_EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(crystals_ensemble_model.evaluate(heart_test_xs, heart_test_ys))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "premade_models.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}


================================================
FILE: docs/tutorials/shape_constraints.ipynb
================================================
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7765UFHoyGx6"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RKQpW0JqQQmY"
      },
      "source": [
        "# Shape Constraints with Tensorflow Lattice\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r61fkA2i9Y3_"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lattice/tutorials/shape_constraints\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/shape_constraints.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/shape_constraints.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/shape_constraints.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2plcL3iTVjsp"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This tutorial is an overview of the constraints and regularizers provided by the TensorFlow Lattice (TFL) library. Here we use TFL premade models on synthetic datasets, but note that everything in this tutorial can also be done with models constructed from TFL Keras layers.\n",
        "\n",
        "Before proceeding, make sure your runtime has all required packages installed (as imported in the code cells below)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x769lI12IZXB"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fbBVAR6UeRN5"
      },
      "source": [
        "Installing TF Lattice package:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bpXjJKpSd3j4"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install -U tensorflow tf-keras tensorflow-lattice pydot graphviz\n",
        "!pip install -U tensorflow_decision_forests"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jSVl9SHTeSGX"
      },
      "source": [
        "Importing required packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "iY6awAl058TV"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_lattice as tfl\n",
        "import tensorflow_decision_forests as tfdf\n",
        "\n",
        "from IPython.core.pylabtools import figsize\n",
        "import functools\n",
        "import logging\n",
        "import matplotlib\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tempfile\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8dsfk2oNlakY"
      },
      "outputs": [],
      "source": [
        "# Use Keras 2.\n",
        "version_fn = getattr(tf.keras, \"version\", None)\n",
        "if version_fn and version_fn().startswith(\"3.\"):\n",
        "  import tf_keras as keras\n",
        "else:\n",
        "  keras = tf.keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7TmBk_IGgJF0"
      },
      "source": [
        "Default values used in this guide:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kQHPyPsPUF92"
      },
      "outputs": [],
      "source": [
        "NUM_EPOCHS = 1000\n",
        "BATCH_SIZE = 64\n",
        "LEARNING_RATE=0.01"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FjR7D8Ag3z0d"
      },
      "source": [
        "## Training Dataset for Ranking Restaurants"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a1YetzbdFOij"
      },
      "source": [
        "Imagine a simplified scenario where we want to determine whether or not users will click on a restaurant search result. The task is to predict the clickthrough rate (CTR) given input features:\n",
        "- Average rating (`avg_rating`): a numeric feature with values in the range [1,5].\n",
        "- Number of reviews (`num_reviews`): a numeric feature with values capped at 200, which we use as a measure of trendiness.\n",
        "- Dollar rating (`dollar_rating`): a categorical feature with string values in the set {\"D\", \"DD\", \"DDD\", \"DDDD\"}.\n",
        "\n",
        "Here we create a synthetic dataset where the true CTR is given by the formula:\n",
        "$$\n",
        "CTR = 1 / (1 + exp\{\mbox{b(dollar_rating)}-\mbox{avg_rating}\times log(1+\mbox{num_reviews})/4\})\n",
        "$$\n",
        "where $b(\\cdot)$ translates each `dollar_rating` to a baseline value:\n",
        "$$\n",
        "\\mbox{D}\\to 3,\\ \\mbox{DD}\\to 2,\\ \\mbox{DDD}\\to 4,\\ \\mbox{DDDD}\\to 4.5.\n",
        "$$\n",
        "\n",
        "This formula reflects typical user patterns: e.g., with everything else fixed, users prefer restaurants with higher star ratings, and \"\\$\\$\" restaurants receive more clicks than \"\\$\", followed by \"\\$\\$\\$\" and \"\\$\\$\\$\\$\"."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mKovnyv1jATw"
      },
      "outputs": [],
      "source": [
        "dollar_ratings_vocab = [\"D\", \"DD\", \"DDD\", \"DDDD\"]\n",
        "def click_through_rate(avg_ratings, num_reviews, dollar_ratings):\n",
        "  dollar_rating_baseline = {\"D\": 3, \"DD\": 2, \"DDD\": 4, \"DDDD\": 4.5}\n",
        "  return 1 / (1 + np.exp(\n",
        "      np.array([dollar_rating_baseline[d] for d in dollar_ratings]) -\n",
        "      avg_ratings * np.log1p(num_reviews) / 4))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BPlgRdt6jAbP"
      },
      "source": [
        "Let's take a look at the contour plots of this CTR function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KC5qX_XKmc7g"
      },
      "outputs": [],
      "source": [
        "def color_bar():\n",
        "  bar = matplotlib.cm.ScalarMappable(\n",
        "      norm=matplotlib.colors.Normalize(0, 1, True),\n",
        "      cmap=\"viridis\",\n",
        "  )\n",
        "  bar.set_array([0, 1])\n",
        "  return bar\n",
        "\n",
        "\n",
        "def plot_fns(fns, res=25):\n",
        "  \"\"\"Generates contour plots for a list of (name, fn) functions.\"\"\"\n",
        "  num_reviews, avg_ratings = np.meshgrid(\n",
        "      np.linspace(0, 200, num=res),\n",
        "      np.linspace(1, 5, num=res),\n",
        "  )\n",
        "  figsize(13, 3.5 * len(fns))\n",
        "  fig, axes = plt.subplots(\n",
        "      len(fns), len(dollar_ratings_vocab), sharey=True, layout=\"constrained\"\n",
        "  )\n",
        "  axes = axes.flatten()\n",
        "  axes_index = 0\n",
        "  for fn_name, fn in fns:\n",
        "    for dollar_rating_split in dollar_ratings_vocab:\n",
        "      dollar_ratings = np.repeat(dollar_rating_split, res**2)\n",
        "      values = fn(avg_ratings.flatten(), num_reviews.flatten(), dollar_ratings)\n",
        "      title = \"{}: dollar_rating={}\".format(fn_name, dollar_rating_split)\n",
        "      subplot = axes[axes_index]\n",
        "      axes_index += 1\n",
        "      subplot.contourf(\n",
        "          avg_ratings,\n",
        "          num_reviews,\n",
        "          np.reshape(values, (res, res)),\n",
        "          vmin=0,\n",
        "          vmax=1,\n",
        "      )\n",
        "      subplot.title.set_text(title)\n",
        "      subplot.set(xlabel=\"Average Rating\")\n",
        "      subplot.set(ylabel=\"Number of Reviews\")\n",
        "      subplot.set(xlim=(1, 5))\n",
        "\n",
        "  if len(fns) \u003c= 2:\n",
        "    cax = fig.add_axes([\n",
        "        axes[-1].get_position().x1 + 0.11,\n",
        "        axes[-1].get_position().y0,\n",
        "        0.02,\n",
        "        0.8,\n",
        "    ])\n",
        "    _ = fig.colorbar(color_bar(), cax=cax)\n",
        "\n",
        "\n",
        "plot_fns([(\"CTR\", click_through_rate)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ol91olp3muNN"
      },
      "source": [
        "### Preparing Data\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H8BOshZS9xwn"
      },
      "source": [
        "We now need to create our synthetic datasets. We start by generating a simulated dataset of restaurants and their features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MhqcOPdTT_wj"
      },
      "outputs": [],
      "source": [
        "def sample_restaurants(n):\n",
        "  avg_ratings = np.random.uniform(1.0, 5.0, n)\n",
        "  num_reviews = np.round(np.exp(np.random.uniform(0.0, np.log(200), n)))\n",
        "  dollar_ratings = np.random.choice(dollar_ratings_vocab, n)\n",
        "  ctr_labels = click_through_rate(avg_ratings, num_reviews, dollar_ratings)\n",
        "  return avg_ratings, num_reviews, dollar_ratings, ctr_labels\n",
        "\n",
        "\n",
        "np.random.seed(42)\n",
        "avg_ratings, num_reviews, dollar_ratings, ctr_labels = sample_restaurants(2000)\n",
        "\n",
        "figsize(5, 5)\n",
        "fig, axs = plt.subplots(1, 1, sharey=False, layout=\"constrained\")\n",
        "\n",
        "for rating, marker in [(\"D\", \"o\"), (\"DD\", \"^\"), (\"DDD\", \"+\"), (\"DDDD\", \"x\")]:\n",
        "  plt.scatter(\n",
        "      x=avg_ratings[np.where(dollar_ratings == rating)],\n",
        "      y=num_reviews[np.where(dollar_ratings == rating)],\n",
        "      c=ctr_labels[np.where(dollar_ratings == rating)],\n",
        "      vmin=0,\n",
        "      vmax=1,\n",
        "      marker=marker,\n",
        "      label=rating)\n",
        "plt.xlabel(\"Average Rating\")\n",
        "plt.ylabel(\"Number of Reviews\")\n",
        "plt.legend()\n",
        "plt.xlim((1, 5))\n",
        "plt.title(\"Distribution of restaurants\")\n",
        "_ = fig.colorbar(color_bar(), cax=fig.add_axes([1.05, 0.1, 0.05, 0.85]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tRetsfLv_JSR"
      },
      "source": [
        "Let's produce the training, validation, and testing datasets. When a restaurant is viewed in the search results, we can record the user's engagement (click or no click) as a sample point.\n",
        "\n",
        "In practice, users often do not go through all search results, so they will likely only see restaurants already considered \"good\" by the ranking model currently in use. As a result, \"good\" restaurants are impressed more frequently and end up over-represented in the training dataset. When using more features, the training dataset can have large gaps in \"bad\" parts of the feature space.\n",
        "\n",
        "When the model is used for ranking, it is often evaluated on all relevant results, whose distribution is more uniform and not well-represented by the training dataset. A flexible and complicated model might fail in this case because it overfits the over-represented data points and thus lacks generalizability. We handle this issue by applying domain knowledge to add *shape constraints* that guide the model toward reasonable predictions when it cannot learn them from the training dataset.\n",
        "\n",
        "In this example, the training dataset mostly consists of user interactions with good and popular restaurants. The testing dataset has a uniform distribution to simulate the evaluation setting discussed above. Note that such a testing dataset will not be available in a real problem setting."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jS6WOtXQ8jwX"
      },
      "outputs": [],
      "source": [
        "def sample_dataset(n, testing_set):\n",
        "  (avg_ratings, num_reviews, dollar_ratings, ctr_labels) = sample_restaurants(n)\n",
        "  if testing_set:\n",
        "    # Testing has a more uniform distribution over all restaurants.\n",
        "    num_views = np.random.poisson(lam=3, size=n)\n",
        "  else:\n",
        "    # Training/validation datasets have more views on popular restaurants.\n",
        "    num_views = np.random.poisson(lam=ctr_labels * num_reviews / 50.0, size=n)\n",
        "\n",
        "  return pd.DataFrame({\n",
        "      \"avg_rating\": np.repeat(avg_ratings, num_views),\n",
        "      \"num_reviews\": np.repeat(num_reviews, num_views),\n",
        "      \"dollar_rating\": np.repeat(dollar_ratings, num_views),\n",
        "      \"clicked\": np.random.binomial(n=1, p=np.repeat(ctr_labels, num_views)),\n",
        "  })\n",
        "\n",
        "\n",
        "# Generate datasets.\n",
        "np.random.seed(42)\n",
        "data_train = sample_dataset(500, testing_set=False)\n",
        "data_val = sample_dataset(500, testing_set=False)\n",
        "data_test = sample_dataset(500, testing_set=True)\n",
        "\n",
        "ds_train = tfdf.keras.pd_dataframe_to_tf_dataset(\n",
        "    data_train, label=\"clicked\", batch_size=BATCH_SIZE\n",
        ")\n",
        "ds_val = tfdf.keras.pd_dataframe_to_tf_dataset(\n",
        "    data_val, label=\"clicked\", batch_size=BATCH_SIZE\n",
        ")\n",
        "ds_test = tfdf.keras.pd_dataframe_to_tf_dataset(\n",
        "    data_test, label=\"clicked\", batch_size=BATCH_SIZE\n",
        ")\n",
        "\n",
        "# feature_analysis_data is used to find quantiles of features.\n",
        "feature_analysis_data = data_train.copy()\n",
        "feature_analysis_data[\"dollar_rating\"] = feature_analysis_data[\n",
        "    \"dollar_rating\"\n",
        "].map({v: i for i, v in enumerate(dollar_ratings_vocab)})\n",
        "feature_analysis_data = dict(feature_analysis_data)\n",
        "\n",
        "# Plotting dataset densities.\n",
        "figsize(12, 5)\n",
        "fig, axs = plt.subplots(1, 2, sharey=False, tight_layout=False)\n",
        "for ax, data, title in [\n",
        "    (axs[0], data_train, \"training\"),\n",
        "    (axs[1], data_test, \"testing\"),\n",
        "]:\n",
        "  _, _, _, density = ax.hist2d(\n",
        "      x=data[\"avg_rating\"],\n",
        "      y=data[\"num_reviews\"],\n",
        "      bins=(np.linspace(1, 5, num=21), np.linspace(0, 200, num=21)),\n",
        "      cmap=\"Blues\",\n",
        "  )\n",
        "  ax.set(xlim=(1, 5))\n",
        "  ax.set(ylim=(0, 200))\n",
        "  ax.set(xlabel=\"Average Rating\")\n",
        "  ax.set(ylabel=\"Number of Reviews\")\n",
        "  ax.title.set_text(\"Density of {} examples\".format(title))\n",
        "  _ = fig.colorbar(density, ax=ax)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qoTrw3FZqvPK"
      },
      "source": [
        "## Fitting Gradient Boosted Trees"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZklNowexE3wB"
      },
      "source": [
        "We first create a few auxiliary functions for plotting and for computing validation and test metrics."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3BqGqScQzlYf"
      },
      "outputs": [],
      "source": [
        "def pred_fn(model, from_logits, avg_ratings, num_reviews, dollar_rating):\n",
        "  preds = model.predict(\n",
        "      tf.data.Dataset.from_tensor_slices({\n",
        "          \"avg_rating\": avg_ratings,\n",
        "          \"num_reviews\": num_reviews,\n",
        "          \"dollar_rating\": dollar_rating,\n",
        "      }).batch(1),\n",
        "      verbose=0,\n",
        "  )\n",
        "  if from_logits:\n",
        "    preds = tf.math.sigmoid(preds)\n",
        "  return preds\n",
        "\n",
        "\n",
        "def analyze_model(models, from_logits=False, print_metrics=True):\n",
        "  pred_fns = []\n",
        "  for model, name in models:\n",
        "    if print_metrics:\n",
        "      metric = model.evaluate(ds_val, return_dict=True, verbose=0)\n",
        "      print(\"Validation AUC: {}\".format(metric[\"auc\"]))\n",
        "      metric = model.evaluate(ds_test, return_dict=True, verbose=0)\n",
        "      print(\"Testing AUC: {}\".format(metric[\"auc\"]))\n",
        "\n",
        "    pred_fns.append(\n",
        "        (\"{} pCTR\".format(name), functools.partial(pred_fn, model, from_logits))\n",
        "    )\n",
        "\n",
        "  pred_fns.append((\"CTR\", click_through_rate))\n",
        "  plot_fns(pred_fns)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JVef4f8yUUbs"
      },
      "source": [
        "We can fit TensorFlow gradient boosted decision trees on the dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DnPYlRAo2mnQ"
      },
      "outputs": [],
      "source": [
        "gbt_model = tfdf.keras.GradientBoostedTreesModel(\n",
        "    features=[\n",
        "        tfdf.keras.FeatureUsage(name=\"num_reviews\"),\n",
        "        tfdf.keras.FeatureUsage(name=\"avg_rating\"),\n",
        "        tfdf.keras.FeatureUsage(name=\"dollar_rating\"),\n",
        "    ],\n",
        "    exclude_non_specified_features=True,\n",
        "    num_threads=1,\n",
        "    num_trees=32,\n",
        "    max_depth=6,\n",
        "    min_examples=10,\n",
        "    growing_strategy=\"BEST_FIRST_GLOBAL\",\n",
        "    random_seed=42,\n",
        "    temp_directory=tempfile.mkdtemp(),\n",
        ")\n",
        "gbt_model.compile(metrics=[keras.metrics.AUC(name=\"auc\")])\n",
        "gbt_model.fit(ds_train, validation_data=ds_val, verbose=0)\n",
        "analyze_model([(gbt_model, \"GBT\")])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nYZtd6YvsNdn"
      },
      "source": [
        "Even though the model has captured the general shape of the true CTR and has decent validation metrics, it has counter-intuitive behavior in several parts of the input space: the estimated CTR decreases as the average rating or the number of reviews increases. This is due to a lack of sample points in areas not well-covered by the training dataset. The model simply has no way to deduce the correct behavior solely from the data.\n",
        "\n",
        "To solve this issue, we enforce the shape constraint that the model must output values monotonically increasing with respect to both the average rating and the number of reviews. We will later see how to implement this in TFL.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uf7WqGooFiEp"
      },
      "source": [
        "## Fitting a DNN"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_s2aT3x0E_tF"
      },
      "source": [
        "We can repeat the same steps with a DNN classifier, and we observe a similar pattern: not having enough sample points for restaurants with a small number of reviews results in nonsensical extrapolation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WKZzCY-UkZX-"
      },
      "outputs": [],
      "source": [
        "keras.utils.set_random_seed(42)\n",
        "inputs = {\n",
        "    \"num_reviews\": keras.Input(shape=(1,), dtype=tf.float32),\n",
        "    \"avg_rating\": keras.Input(shape=(1), dtype=tf.float32),\n",
        "    \"dollar_rating\": keras.Input(shape=(1), dtype=tf.string),\n",
        "}\n",
        "inputs_flat = keras.layers.Concatenate()([\n",
        "    inputs[\"num_reviews\"],\n",
        "    inputs[\"avg_rating\"],\n",
        "    keras.layers.StringLookup(\n",
        "        vocabulary=dollar_ratings_vocab,\n",
        "        num_oov_indices=0,\n",
        "        output_mode=\"one_hot\",\n",
        "    )(inputs[\"dollar_rating\"]),\n",
        "])\n",
        "dense_layers = keras.Sequential(\n",
        "    [\n",
        "        keras.layers.Dense(16, activation=\"relu\"),\n",
        "        keras.layers.Dense(16, activation=\"relu\"),\n",
        "        keras.layers.Dense(1, activation=None),\n",
        "    ],\n",
        "    name=\"dense_layers\",\n",
        ")\n",
        "dnn_model = keras.Model(inputs=inputs, outputs=dense_layers(inputs_flat))\n",
        "keras.utils.plot_model(\n",
        "    dnn_model, expand_nested=True, show_layer_names=False, rankdir=\"LR\"\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6zFqu6wf1I30"
      },
      "outputs": [],
      "source": [
        "dnn_model.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True, name=\"auc\")],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE),\n",
        ")\n",
        "dnn_model.fit(ds_train, epochs=200, verbose=0)\n",
        "analyze_model([(dnn_model, \"DNN\")], from_logits=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Avkw-okw7JL"
      },
      "source": [
        "## Shape Constraints"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3ExyethCFBrP"
      },
      "source": [
        "TensorFlow Lattice (TFL) is focused on enforcing shape constraints to safeguard model behavior beyond the training data. These shape constraints are applied to TFL Keras layers. Their details can be found in [our JMLR paper](http://jmlr.org/papers/volume17/15-243/15-243.pdf).\n",
        "\n",
        "In this tutorial we use TFL premade models to cover various shape constraints, but note that all of these steps can also be done with models created from TFL Keras layers.\n",
        "\n",
        "Using TFL premade models also requires:\n",
        "- a *model config*: defining the model architecture and per-feature shape constraints and regularizers.\n",
        "- a *feature analysis dataset*: a dataset used for TFL initialization (feature quantile calculation).\n",
        "\n",
        "For a more thorough description, please refer to the premade models or the API docs."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "anyCM4sCpOSo"
      },
      "source": [
        "### Monotonicity\n",
        "We first address the monotonicity concerns by adding monotonicity shape constraints to the continuous features. We use a calibrated lattice model with added output calibration: each feature is calibrated using categorical or piecewise-linear calibrators, then fed into a lattice model, followed by an output piecewise-linear calibrator.\n",
        "\n",
        "To instruct TFL to enforce shape constraints, we specify the constraints in the *feature configs*. The following code shows how we can require the output to be monotonically increasing with respect to both `num_reviews` and `avg_rating` by setting `monotonicity=\"increasing\"`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hFlkZs5RgFcP"
      },
      "outputs": [],
      "source": [
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=3,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=32,\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=3,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=32,\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"dollar_rating\",\n",
        "            lattice_size=3,\n",
        "            pwl_calibration_num_keypoints=4,\n",
        "            vocabulary_list=dollar_ratings_vocab,\n",
        "            num_buckets=len(dollar_ratings_vocab),\n",
        "        ),\n",
        "    ],\n",
        "    output_calibration=True,\n",
        "    output_initialization=np.linspace(-2, 2, num=5),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GOlzuyQsGre5"
      },
      "source": [
        "We now use the `feature_analysis_data` to find and set the quantile values for the input features. These values can be pre-calculated and set explicitly in the feature config depending on the training pipeline."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f-bTmfBnghuX"
      },
      "outputs": [],
      "source": [
        "feature_analysis_data = data_train.copy()\n",
        "feature_analysis_data[\"dollar_rating\"] = feature_analysis_data[\n",
        "    \"dollar_rating\"\n",
        "].map({v: i for i, v in enumerate(dollar_ratings_vocab)})\n",
        "feature_analysis_data = dict(feature_analysis_data)\n",
        "\n",
        "feature_keypoints = tfl.premade_lib.compute_feature_keypoints(\n",
        "    feature_configs=model_config.feature_configs, features=feature_analysis_data\n",
        ")\n",
        "tfl.premade_lib.set_feature_keypoints(\n",
        "    feature_configs=model_config.feature_configs,\n",
        "    feature_keypoints=feature_keypoints,\n",
        "    add_missing_feature_configs=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FCm1lOjmwur_"
      },
      "outputs": [],
      "source": [
        "keras.utils.set_random_seed(42)\n",
        "inputs = {\n",
        "    \"num_reviews\": keras.Input(shape=(1,), dtype=tf.float32),\n",
        "    \"avg_rating\": keras.Input(shape=(1), dtype=tf.float32),\n",
        "    \"dollar_rating\": keras.Input(shape=(1), dtype=tf.string),\n",
        "}\n",
        "ordered_inputs = [\n",
        "    inputs[\"num_reviews\"],\n",
        "    inputs[\"avg_rating\"],\n",
        "    keras.layers.StringLookup(\n",
        "        vocabulary=dollar_ratings_vocab,\n",
        "        num_oov_indices=0,\n",
        "        output_mode=\"int\",\n",
        "    )(inputs[\"dollar_rating\"]),\n",
        "]\n",
        "outputs = tfl.premade.CalibratedLattice(\n",
        "    model_config=model_config, name=\"CalibratedLattice\"\n",
        ")(ordered_inputs)\n",
        "tfl_model_0 = keras.Model(inputs=inputs, outputs=outputs)\n",
        "\n",
        "keras.utils.plot_model(\n",
        "    tfl_model_0, expand_nested=True, show_layer_names=False, rankdir=\"LR\"\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ubNRBCWW5wQ9"
      },
      "source": [
        "Using a `CalibratedLatticeConfig` creates a premade classifier that first applies a *calibrator* to each input (a piecewise-linear function for numeric features) followed by a *lattice* layer to non-linearly fuse the calibrated features. We have also enabled output piecewise-linear calibration.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Am1OwtzzU7no"
      },
      "outputs": [],
      "source": [
        "tfl_model_0.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True, name=\"auc\")],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE),\n",
        ")\n",
        "tfl_model_0.fit(ds_train, epochs=100, verbose=0)\n",
        "analyze_model([(tfl_model_0, \"TFL0\")], from_logits=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7vZ5fShXs504"
      },
      "source": [
        "With the constraints added, the estimated CTR will always increase as the average rating increases or the number of reviews increases. This is done by making sure that the calibrators and the lattice are monotonic."
      ]
    },
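    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check, we can sweep `avg_rating` while holding the other inputs fixed and verify numerically that the predicted CTR never decreases. This is a minimal sketch reusing the `pred_fn` helper defined earlier; the fixed values (100 reviews, \"DD\") are arbitrary choices for illustration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sweep avg_rating with num_reviews and dollar_rating held fixed, then check\n",
        "# that the monotonicity-constrained model's predictions are non-decreasing.\n",
        "sweep = np.linspace(1.0, 5.0, num=20)\n",
        "preds = np.ravel(\n",
        "    pred_fn(tfl_model_0, True, sweep, np.repeat(100.0, 20),\n",
        "            np.repeat(\"DD\", 20)))\n",
        "print(\"Monotonic in avg_rating:\", bool(np.all(np.diff(preds) \u003e= -1e-6)))"
      ]
    },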
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pSUd6aFlpYz4"
      },
      "source": [
        "### Partial Monotonicity for Categorical Calibration\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CnPiqf4rq6kJ"
      },
      "source": [
        "To apply constraints to the third feature, `dollar_rating`, recall that categorical features require a slightly different treatment in TFL. Here we enforce the partial monotonicity constraint that, with all other inputs fixed, the output for \"DD\" restaurants should be larger than for \"D\" restaurants. This is done using the `monotonicity` setting in the feature config. We also need to use `tfl.premade_lib.set_categorical_monotonicities` to convert the constraints specified as string values into the numerical format understood by the library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FH2ItfsTsE3S"
      },
      "outputs": [],
      "source": [
        "keras.utils.set_random_seed(42)\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=3,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_convexity=\"concave\",\n",
        "            pwl_calibration_num_keypoints=32,\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=3,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=32,\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"dollar_rating\",\n",
        "            lattice_size=3,\n",
        "            pwl_calibration_num_keypoints=4,\n",
        "            vocabulary_list=dollar_ratings_vocab,\n",
        "            num_buckets=len(dollar_ratings_vocab),\n",
        "            monotonicity=[(\"D\", \"DD\")],\n",
        "        ),\n",
        "    ],\n",
        "    output_calibration=True,\n",
        "    output_initialization=np.linspace(-2, 2, num=5),\n",
        ")\n",
        "\n",
        "tfl.premade_lib.set_feature_keypoints(\n",
        "    feature_configs=model_config.feature_configs,\n",
        "    feature_keypoints=feature_keypoints,\n",
        "    add_missing_feature_configs=False,\n",
        ")\n",
        "tfl.premade_lib.set_categorical_monotonicities(model_config.feature_configs)\n",
        "\n",
        "outputs = tfl.premade.CalibratedLattice(\n",
        "    model_config=model_config, name=\"CalibratedLattice\"\n",
        ")(ordered_inputs)\n",
        "tfl_model_1 = keras.Model(inputs=inputs, outputs=outputs)\n",
        "tfl_model_1.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True, name=\"auc\")],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE),\n",
        ")\n",
        "tfl_model_1.fit(ds_train, epochs=100, verbose=0)\n",
        "analyze_model([(tfl_model_1, \"TFL1\")], from_logits=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gdIzhYL79_Pp"
      },
      "source": [
        "Here we also plot the predicted CTR of this model conditioned on `dollar_rating`. Notice that all the constraints we required are fulfilled in each of the slices."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J6CP2Ovapiu3"
      },
      "source": [
        "### 2D Shape Constraint: Trust\n",
        "A 5-star rating for a restaurant with only one or two reviews is likely an unreliable rating (the restaurant might not actually be good), whereas a 4-star rating for a restaurant with hundreds of reviews is much more reliable (the restaurant is likely good in this case). We can see that the number of reviews of a restaurant affects how much trust we place in its average rating.\n",
        "\n",
        "We can exercise TFL trust constraints to inform the model that the larger (or smaller) value of one feature indicates more reliance or trust of another feature. This is done by setting `reflects_trust_in` configuration in the feature config."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OA14j0erm6TJ"
      },
      "outputs": [],
      "source": [
        "keras.utils.set_random_seed(42)\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=3,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=32,\n",
        "            # Larger num_reviews indicating more trust in avg_rating.\n",
        "            reflects_trust_in=[\n",
        "                tfl.configs.TrustConfig(\n",
        "                    feature_name=\"avg_rating\", trust_type=\"edgeworth\"\n",
        "                ),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=3,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=32,\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"dollar_rating\",\n",
        "            lattice_size=3,\n",
        "            pwl_calibration_num_keypoints=4,\n",
        "            vocabulary_list=dollar_ratings_vocab,\n",
        "            num_buckets=len(dollar_ratings_vocab),\n",
        "            monotonicity=[(\"D\", \"DD\")],\n",
        "        ),\n",
        "    ],\n",
        "    output_calibration=True,\n",
        "    output_initialization=np.linspace(-2, 2, num=5),\n",
        ")\n",
        "\n",
        "tfl.premade_lib.set_feature_keypoints(\n",
        "    feature_configs=model_config.feature_configs,\n",
        "    feature_keypoints=feature_keypoints,\n",
        "    add_missing_feature_configs=False,\n",
        ")\n",
        "tfl.premade_lib.set_categorical_monotonicities(model_config.feature_configs)\n",
        "\n",
        "outputs = tfl.premade.CalibratedLattice(\n",
        "    model_config=model_config, name=\"CalibratedLattice\"\n",
        ")(ordered_inputs)\n",
        "tfl_model_2 = keras.Model(inputs=inputs, outputs=outputs)\n",
        "tfl_model_2.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True, name=\"auc\")],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE),\n",
        ")\n",
        "tfl_model_2.fit(ds_train, epochs=100, verbose=0)\n",
        "analyze_model([(tfl_model_2, \"TFL2\")], from_logits=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "puvP9X8XxyRV"
      },
      "source": [
        "The following plot presents the trained lattice function. Due to the trust constraint, we expect that larger values of calibrated `num_reviews` would force higher slope with respect to calibrated `avg_rating`, resulting in a more significant move in the lattice output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "RounEQebxxnA"
      },
      "outputs": [],
      "source": [
        "lattice_params = tfl_model_2.layers[-1].layers[-2].weights[0].numpy()\n",
        "lat_mesh_x, lat_mesh_y = np.meshgrid(\n",
        "    np.linspace(0, 1, num=3),\n",
        "    np.linspace(0, 1, num=3),\n",
        ")\n",
        "lat_mesh_z = np.reshape(np.asarray(lattice_params[0::3]), (3, 3))\n",
        "\n",
        "figure = plt.figure(figsize=(6, 6))\n",
        "axes = figure.add_subplot(projection=\"3d\")\n",
        "axes.plot_wireframe(lat_mesh_x, lat_mesh_y, lat_mesh_z, color=\"dodgerblue\")\n",
        "plt.legend([\"Lattice Lookup\"])\n",
        "plt.title(\"Trust\")\n",
        "plt.xlabel(\"Calibrated avg_rating\")\n",
        "plt.ylabel(\"Calibrated num_reviews\")\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RfniRZCHIvfK"
      },
      "source": [
        "### Diminishing Returns\n",
        "[Diminishing returns](https://en.wikipedia.org/wiki/Diminishing_returns) means that the marginal gain of increasing a certain feature value will decrease as we increase the value. In our case we expect that the `num_reviews` feature follows this pattern, so we can configure its calibrator accordingly. Notice that we can decompose diminishing returns into two sufficient conditions:\n",
        "\n",
        "- the calibrator is monotonicially increasing, and\n",
        "- the calibrator is concave (setting `pwl_calibration_convexity=\"concave\"`).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XQrM9BskY-wx"
      },
      "outputs": [],
      "source": [
        "keras.utils.set_random_seed(42)\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=3,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_convexity=\"concave\",\n",
        "            pwl_calibration_num_keypoints=32,\n",
        "            reflects_trust_in=[\n",
        "                tfl.configs.TrustConfig(\n",
        "                    feature_name=\"avg_rating\", trust_type=\"edgeworth\"\n",
        "                ),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=3,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=32,\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"dollar_rating\",\n",
        "            lattice_size=3,\n",
        "            pwl_calibration_num_keypoints=4,\n",
        "            vocabulary_list=dollar_ratings_vocab,\n",
        "            num_buckets=len(dollar_ratings_vocab),\n",
        "            monotonicity=[(\"D\", \"DD\")],\n",
        "        ),\n",
        "    ],\n",
        "    output_calibration=True,\n",
        "    output_initialization=np.linspace(-2, 2, num=5),\n",
        ")\n",
        "\n",
        "tfl.premade_lib.set_feature_keypoints(\n",
        "    feature_configs=model_config.feature_configs,\n",
        "    feature_keypoints=feature_keypoints,\n",
        "    add_missing_feature_configs=False,\n",
        ")\n",
        "tfl.premade_lib.set_categorical_monotonicities(model_config.feature_configs)\n",
        "\n",
        "outputs = tfl.premade.CalibratedLattice(\n",
        "    model_config=model_config, name=\"CalibratedLattice\"\n",
        ")(ordered_inputs)\n",
        "tfl_model_3 = keras.Model(inputs=inputs, outputs=outputs)\n",
        "tfl_model_3.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True, name=\"auc\")],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE),\n",
        ")\n",
        "tfl_model_3.fit(\n",
        "    ds_train,\n",
        "    epochs=100,\n",
        "    verbose=0\n",
        ")\n",
        "analyze_model([(tfl_model_3, \"TFL3\")], from_logits=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LSmzHkPUo9u5"
      },
      "source": [
        "Notice how the testing metric improves by adding the concavity constraint. The prediction plot also better resembles the ground truth."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SKe3UHX6pUjw"
      },
      "source": [
        "### Smoothing Calibrators\n",
        "We notice in the prediction curves above that even though the output is monotonic in specified features, the changes in the slopes are abrupt and hard to interpret. That suggests we might want to consider smoothing this calibrator using a regularizer setup in the `regularizer_configs`.\n",
        "\n",
        "Here we apply a `hessian` regularizer to make the calibration more linear. You can also use the `laplacian` regularizer to flatten the calibrator and the `wrinkle` regularizer to reduce changes in the curvature.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CxcCNxhkqC7u"
      },
      "outputs": [],
      "source": [
        "keras.utils.set_random_seed(42)\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=3,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_convexity=\"concave\",\n",
        "            pwl_calibration_num_keypoints=32,\n",
        "            regularizer_configs=[\n",
        "                tfl.configs.RegularizerConfig(name=\"calib_hessian\", l2=0.5),\n",
        "            ],\n",
        "            reflects_trust_in=[\n",
        "                tfl.configs.TrustConfig(\n",
        "                    feature_name=\"avg_rating\", trust_type=\"edgeworth\"\n",
        "                ),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=3,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=32,\n",
        "            regularizer_configs=[\n",
        "                tfl.configs.RegularizerConfig(name=\"calib_hessian\", l2=0.5),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"dollar_rating\",\n",
        "            lattice_size=3,\n",
        "            pwl_calibration_num_keypoints=4,\n",
        "            vocabulary_list=dollar_ratings_vocab,\n",
        "            num_buckets=len(dollar_ratings_vocab),\n",
        "            monotonicity=[(\"D\", \"DD\")],\n",
        "        ),\n",
        "    ],\n",
        "    output_calibration=True,\n",
        "    output_initialization=np.linspace(-2, 2, num=5),\n",
        "    regularizer_configs=[\n",
        "        tfl.configs.RegularizerConfig(name=\"calib_hessian\", l2=0.1),\n",
        "    ],\n",
        ")\n",
        "\n",
        "tfl.premade_lib.set_feature_keypoints(\n",
        "    feature_configs=model_config.feature_configs,\n",
        "    feature_keypoints=feature_keypoints,\n",
        "    add_missing_feature_configs=False,\n",
        ")\n",
        "tfl.premade_lib.set_categorical_monotonicities(model_config.feature_configs)\n",
        "\n",
        "outputs = tfl.premade.CalibratedLattice(\n",
        "    model_config=model_config, name=\"CalibratedLattice\"\n",
        ")(ordered_inputs)\n",
        "tfl_model_4 = keras.Model(inputs=inputs, outputs=outputs)\n",
        "tfl_model_4.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.AUC(from_logits=True, name=\"auc\")],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATE),\n",
        ")\n",
        "tfl_model_4.fit(ds_train, epochs=100, verbose=0)\n",
        "analyze_model([(tfl_model_4, \"TFL4\")], from_logits=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HHpp4goLvuPi"
      },
      "source": [
        "The calibrators are now smooth, and the overall estimated CTR better matches the ground truth. This is reflected both in the testing metric and in the contour plots."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TLOGDrYY0hH7"
      },
      "source": [
        "Here you can see the results of each step as we added domain-specific constraints and regularizers to the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nUEuihX815ix"
      },
      "outputs": [],
      "source": [
        "analyze_model(\n",
        "    [\n",
        "        (tfl_model_0, \"TFL0\"),\n",
        "        (tfl_model_1, \"TFL1\"),\n",
        "        (tfl_model_2, \"TFL2\"),\n",
        "        (tfl_model_3, \"TFL3\"),\n",
        "        (tfl_model_4, \"TFL4\"),\n",
        "    ],\n",
        "    from_logits=True,\n",
        "    print_metrics=False,\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "shape_constraints.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}


================================================
FILE: docs/tutorials/shape_constraints_for_ethics.ipynb
================================================
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R2AxpObRncMd"
      },
      "source": [
        "***Copyright 2020 The TensorFlow Authors.***"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "gQ5Kfh1YnkFS"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uc0VwsT5nvQi"
      },
      "source": [
        "# Shape Constraints for Ethics with Tensorflow Lattice"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gqJQZdvfn32j"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lattice/tutorials/shape_constraints_for_ethics\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/shape_constraints_for_ethics.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/shape_constraints_for_ethics.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/shape_constraints_for_ethics.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YFZbuZMAoBny"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This tutorial demonstrates how the TensorFlow Lattice (TFL) library can be used\n",
        "to train models that behave *responsibly*, and do not violate certain\n",
        "assumptions that are *ethical* or *fair*. In particular, we will focus on using monotonicity constraints to avoid *unfair penalization* of certain attributes. This tutorial includes demonstrations\n",
        "of the experiments from the paper\n",
        "[*Deontological Ethics By Monotonicity Shape Constraints*](https://arxiv.org/abs/2001.11990)\n",
        "by Serena Wang and Maya Gupta, published at\n",
        "[AISTATS 2020](https://www.aistats.org/).\n",
        "\n",
        "We will use TFL premade models on public datasets, but note that\n",
        "everything in this tutorial can also be done with models constructed from TFL\n",
        "Keras layers.\n",
        "\n",
        "Before proceeding, make sure your runtime has all required packages installed\n",
        "(as imported in the code cells below)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o4L76T-NpgCS"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6FvmHcqbpkL7"
      },
      "source": [
        "Installing TF Lattice package:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f91yvUt_peYs"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install -U tensorflow tf-keras tensorflow-lattice seaborn pydot graphviz\n",
        "!pip install -U tensorflow_decision_forests"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6TDoQsvSpmfx"
      },
      "source": [
        "Importing required packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KGt0pm0b1O5X"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_lattice as tfl\n",
        "import tensorflow_decision_forests as tfdf\n",
        "\n",
        "import logging\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import os\n",
        "import pandas as pd\n",
        "import seaborn as sns\n",
        "from sklearn.model_selection import train_test_split\n",
        "import sys\n",
        "import tempfile\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "csVitiM20zAY"
      },
      "outputs": [],
      "source": [
        "# Use Keras 2.\n",
        "version_fn = getattr(tf.keras, \"version\", None)\n",
        "if version_fn and version_fn().startswith(\"3.\"):\n",
        "  import tf_keras as keras\n",
        "else:\n",
        "  keras = tf.keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DFN6GOcBAqzv"
      },
      "source": [
        "Default values used in this tutorial:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9uqMM2joAnoW"
      },
      "outputs": [],
      "source": [
        "# Default number of training epochs, batch sizes and learning rate.\n",
        "NUM_EPOCHS = 256\n",
        "BATCH_SIZE = 256\n",
        "LEARNING_RATES = 0.01\n",
        "# Directory containing dataset files.\n",
        "DATA_DIR = 'https://raw.githubusercontent.com/serenalwang/shape_constraints_for_ethics/master'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OZJQfJvY3ibC"
      },
      "source": [
        "# Case study #1: Law school admissions\n",
        "\n",
        "In the first part of this tutorial, we will consider a case study using the Law\n",
        "School Admissions dataset from the Law School Admissions Council (LSAC). We will\n",
        "train a classifier to predict whether or not a student will pass the bar using\n",
        "two features: the student's LSAT score and undergraduate GPA.\n",
        "\n",
        "Suppose that the classifier’s score was used to guide law school admissions or\n",
        "scholarships. According to merit-based social norms, we would expect that\n",
        "students with higher GPA and higher LSAT score should receive a higher score\n",
        "from the classifier. However, we will observe that it is easy for models to\n",
        "violate these intuitive norms, and sometimes penalize people for having a higher\n",
        "GPA or LSAT score.\n",
        "\n",
        "To address this *unfair penalization* problem, we can impose monotonicity\n",
        "constraints so that a model never penalizes higher GPA or higher LSAT score, all\n",
        "else equal. In this tutorial, we will show how to impose those monotonicity\n",
        "constraints using TFL."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vJES8lYT1fHN"
      },
      "source": [
        "## Load Law School Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cl89ZOsQ14An"
      },
      "outputs": [],
      "source": [
        "# Load data file.\n",
        "law_file_name = 'lsac.csv'\n",
        "law_file_path = os.path.join(DATA_DIR, law_file_name)\n",
        "raw_law_df = pd.read_csv(law_file_path, delimiter=',')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RCpTYCNjqOsC"
      },
      "source": [
        "Preprocess dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jdY5rtLs4xQK"
      },
      "outputs": [],
      "source": [
        "# Define label column name.\n",
        "LAW_LABEL = 'pass_bar'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1t1Hd8gu6Uat"
      },
      "outputs": [],
      "source": [
        "def preprocess_law_data(input_df):\n",
        "  # Drop rows with where the label or features of interest are missing.\n",
        "  output_df = input_df[~input_df[LAW_LABEL].isna() \u0026 ~input_df['ugpa'].isna() \u0026\n",
        "                       (input_df['ugpa'] \u003e 0) \u0026 ~input_df['lsat'].isna()]\n",
        "  return output_df\n",
        "\n",
        "\n",
        "law_df = preprocess_law_data(raw_law_df)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YhvSKr9SCrHP"
      },
      "source": [
        "### Split data into train/validation/test sets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gQKkIGD-CvGD"
      },
      "outputs": [],
      "source": [
        "def split_dataset(input_df, random_state=888):\n",
        "  \"\"\"Splits an input dataset into train, val, and test sets.\"\"\"\n",
        "  train_df, test_val_df = train_test_split(\n",
        "      input_df, test_size=0.3, random_state=random_state\n",
        "  )\n",
        "  val_df, test_df = train_test_split(\n",
        "      test_val_df, test_size=0.66, random_state=random_state\n",
        "  )\n",
        "  return train_df, val_df, test_df\n",
        "\n",
        "\n",
        "dataframes = {}\n",
        "datasets = {}\n",
        "\n",
        "(dataframes['law_train'], dataframes['law_val'], dataframes['law_test']) = (\n",
        "    split_dataset(law_df)\n",
        ")\n",
        "\n",
        "for df_name, df in dataframes.items():\n",
        "  datasets[df_name] = tf.data.Dataset.from_tensor_slices(\n",
        "      ((df[['ugpa']], df[['lsat']]), df[['pass_bar']])\n",
        "  ).batch(BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zObwzY7f3aLy"
      },
      "source": [
        "### Visualize data distribution\n",
        "\n",
        "First we will visualize the distribution of the data. We will plot the GPA and\n",
        "LSAT scores for all students that passed the bar and also for all students that\n",
        "did not pass the bar."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dRAZB5cLORUG"
      },
      "outputs": [],
      "source": [
        "def plot_dataset_contour(input_df, title):\n",
        "  plt.rcParams['font.family'] = ['serif']\n",
        "  g = sns.jointplot(\n",
        "      x='ugpa',\n",
        "      y='lsat',\n",
        "      data=input_df,\n",
        "      kind='kde',\n",
        "      xlim=[1.4, 4],\n",
        "      ylim=[0, 50])\n",
        "  g.plot_joint(plt.scatter, c='b', s=10, linewidth=1, marker='+')\n",
        "  g.ax_joint.collections[0].set_alpha(0)\n",
        "  g.set_axis_labels('Undergraduate GPA', 'LSAT score', fontsize=14)\n",
        "  g.fig.suptitle(title, fontsize=14)\n",
        "  # Adust plot so that the title fits.\n",
        "  plt.subplots_adjust(top=0.9)\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "feovlsWPQhVG"
      },
      "outputs": [],
      "source": [
        "law_df_pos = law_df[law_df[LAW_LABEL] == 1]\n",
        "plot_dataset_contour(\n",
        "    law_df_pos, title='Distribution of students that passed the bar')\n",
        "law_df_neg = law_df[law_df[LAW_LABEL] == 0]\n",
        "plot_dataset_contour(\n",
        "    law_df_neg, title='Distribution of students that failed the bar')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6grrFEMPfPjk"
      },
      "source": [
        "## Train calibrated lattice model to predict bar exam passage\n",
        "\n",
        "Next, we will train a *calibrated lattice model* from TFL to predict whether or\n",
        "not a student will pass the bar. The two input features will be LSAT score and\n",
        "undergraduate GPA, and the training label will be whether the student passed the\n",
        "bar.\n",
        "\n",
        "We will first train a calibrated lattice model without any constraints. Then, we\n",
        "will train a calibrated lattice model with monotonicity constraints and observe\n",
        "the difference in the model output and accuracy."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HSfAwgiO_6YA"
      },
      "source": [
        "### Helper functions for visualization of trained model outputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aw28Xc7IS6vR"
      },
      "outputs": [],
      "source": [
        "def plot_model_contour(model, from_logits=False, num_keypoints=20):\n",
        "  x = np.linspace(min(law_df['ugpa']), max(law_df['ugpa']), num_keypoints)\n",
        "  y = np.linspace(min(law_df['lsat']), max(law_df['lsat']), num_keypoints)\n",
        "\n",
        "  x_grid, y_grid = np.meshgrid(x, y)\n",
        "\n",
        "  positions = np.vstack([x_grid.ravel(), y_grid.ravel()])\n",
        "  plot_df = pd.DataFrame(positions.T, columns=['ugpa', 'lsat'])\n",
        "  plot_df[LAW_LABEL] = np.ones(len(plot_df))\n",
        "  predictions = model.predict((plot_df[['ugpa']], plot_df[['lsat']]))\n",
        "  if from_logits:\n",
        "    predictions = tf.math.sigmoid(predictions)\n",
        "  grid_predictions = np.reshape(predictions, x_grid.shape)\n",
        "\n",
        "  plt.rcParams['font.family'] = ['serif']\n",
        "  plt.contour(\n",
        "      x_grid,\n",
        "      y_grid,\n",
        "      grid_predictions,\n",
        "      colors=('k',),\n",
        "      levels=np.linspace(0, 1, 11),\n",
        "  )\n",
        "  plt.contourf(\n",
        "      x_grid,\n",
        "      y_grid,\n",
        "      grid_predictions,\n",
        "      cmap=plt.cm.bone,\n",
        "      levels=np.linspace(0, 1, 11),\n",
        "  )\n",
        "  plt.xticks(fontsize=20)\n",
        "  plt.yticks(fontsize=20)\n",
        "\n",
        "  cbar = plt.colorbar()\n",
        "  cbar.ax.set_ylabel('Model score', fontsize=20)\n",
        "  cbar.ax.tick_params(labelsize=20)\n",
        "\n",
        "  plt.xlabel('Undergraduate GPA', fontsize=20)\n",
        "  plt.ylabel('LSAT score', fontsize=20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fAMSCaRHIn1w"
      },
      "source": [
        "## Train unconstrained (non-monotonic) calibrated lattice model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mK7RWDJ5ugdd"
      },
      "source": [
        "We create a TFL premade model using a '`CalibratedLatticeConfig`. This model is a calibrated lattice model with an output calibration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J16TOicHQ1sM"
      },
      "outputs": [],
      "source": [
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name='ugpa',\n",
        "            lattice_size=3,\n",
        "            pwl_calibration_num_keypoints=16,\n",
        "            monotonicity=0,\n",
        "            pwl_calibration_always_monotonic=False,\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name='lsat',\n",
        "            lattice_size=3,\n",
        "            pwl_calibration_num_keypoints=16,\n",
        "            monotonicity=0,\n",
        "            pwl_calibration_always_monotonic=False,\n",
        "        ),\n",
        "    ],\n",
        "    output_calibration=True,\n",
        "    output_initialization=np.linspace(-2, 2, num=8),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jt1Rm6qCuuat"
      },
      "source": [
        "We calculate and populate feature quantiles in the feature configs using the `premade_lib` API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eSELqBdURE0F"
      },
      "outputs": [],
      "source": [
        "feature_keypoints = tfl.premade_lib.compute_feature_keypoints(\n",
        "    feature_configs=model_config.feature_configs,\n",
        "    features=dataframes['law_train'][['ugpa', 'lsat', 'pass_bar']],\n",
        ")\n",
        "tfl.premade_lib.set_feature_keypoints(\n",
        "    feature_configs=model_config.feature_configs,\n",
        "    feature_keypoints=feature_keypoints,\n",
        "    add_missing_feature_configs=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ahV2Sn0Xz1aO"
      },
      "outputs": [],
      "source": [
        "nomon_lattice_model = tfl.premade.CalibratedLattice(model_config=model_config)\n",
        "keras.utils.plot_model(\n",
        "    nomon_lattice_model, expand_nested=True, show_layer_names=False, rankdir=\"LR\"\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Oc5f-6zNtyxr"
      },
      "outputs": [],
      "source": [
        "nomon_lattice_model.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[\n",
        "        keras.metrics.BinaryAccuracy(name='accuracy'),\n",
        "    ],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATES),\n",
        ")\n",
        "nomon_lattice_model.fit(datasets['law_train'], epochs=NUM_EPOCHS, verbose=0)\n",
        "\n",
        "train_acc = nomon_lattice_model.evaluate(datasets['law_train'])[1]\n",
        "val_acc = nomon_lattice_model.evaluate(datasets['law_val'])[1]\n",
        "test_acc = nomon_lattice_model.evaluate(datasets['law_test'])[1]\n",
        "print(\n",
        "    'accuracies for train: %f, val: %f, test: %f'\n",
        "    % (train_acc, val_acc, test_acc)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LuFxP9lDTZup"
      },
      "outputs": [],
      "source": [
        "plot_model_contour(nomon_lattice_model, from_logits=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eKVkjHg_LaWb"
      },
      "source": [
        "## Train monotonic calibrated lattice model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W42OXWLVwx3w"
      },
      "source": [
        "We can get a monotonic model by setting the monotonicity constraints in feature configs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XeOKlPRc0BQe"
      },
      "outputs": [],
      "source": [
        "model_config.feature_configs[0].monotonicity = 1\n",
        "model_config.feature_configs[1].monotonicity = 1"
      ]
    },
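    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sketch of the convention (assuming the `monotonicity` attribute of `tfl.configs.FeatureConfig` used above): `1` requests a monotonically increasing constraint, `-1` decreasing, and `0` leaves the feature unconstrained. When every feature should be constrained, the two assignments above generalize to a loop:\n",
        "\n",
        "```python\n",
        "# Minimal sketch: mark every feature in the config as monotonically increasing.\n",
        "for feature_config in model_config.feature_configs:\n",
        "    feature_config.monotonicity = 1  # 1 = increasing, -1 = decreasing, 0 = none\n",
        "```"
      ]
    },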
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C_MUEvGNp6g2"
      },
      "outputs": [],
      "source": [
        "mon_lattice_model = tfl.premade.CalibratedLattice(model_config=model_config)\n",
        "\n",
        "mon_lattice_model.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[\n",
        "        keras.metrics.BinaryAccuracy(name='accuracy'),\n",
        "    ],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATES),\n",
        ")\n",
        "mon_lattice_model.fit(datasets['law_train'], epochs=NUM_EPOCHS, verbose=0)\n",
        "\n",
        "train_acc = mon_lattice_model.evaluate(datasets['law_train'])[1]\n",
        "val_acc = mon_lattice_model.evaluate(datasets['law_val'])[1]\n",
        "test_acc = mon_lattice_model.evaluate(datasets['law_test'])[1]\n",
        "print(\n",
        "    'accuracies for train: %f, val: %f, test: %f'\n",
        "    % (train_acc, val_acc, test_acc)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ABdhYOUVCXzD"
      },
      "outputs": [],
      "source": [
        "plot_model_contour(mon_lattice_model, from_logits=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GWzBEV_p0WE-"
      },
      "source": [
        "We demonstrated that TFL calibrated lattice models could be trained to be\n",
        "monotonic in both LSAT score and GPA without too big of a sacrifice in accuracy."
      ]
    },
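    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to sanity-check the constraint numerically (a minimal sketch, assuming the same two-input `predict` signature used in `plot_model_contour` above; the fixed LSAT value of 30 is an arbitrary choice) is to sweep one feature while holding the other fixed and confirm the predictions never decrease:\n",
        "\n",
        "```python\n",
        "# Hold LSAT fixed and sweep undergraduate GPA; monotonicity implies\n",
        "# the resulting predictions are non-decreasing along the sweep.\n",
        "ugpa_sweep = np.linspace(1.5, 4.0, num=50)\n",
        "lsat_fixed = np.full_like(ugpa_sweep, 30.0)  # hypothetical fixed value\n",
        "preds = mon_lattice_model.predict(\n",
        "    (pd.DataFrame({'ugpa': ugpa_sweep}), pd.DataFrame({'lsat': lsat_fixed}))\n",
        ").flatten()\n",
        "assert np.all(np.diff(preds) >= 0), 'monotonicity violated'\n",
        "```"
      ]
    },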
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fsI14lrFxRha"
      },
      "source": [
        "## Train other unconstrained models\n",
        "\n",
        "How does the calibrated lattice model compare to other types of models, like\n",
        "deep neural networks (DNNs) or gradient boosted trees (GBTs)? Do DNNs and GBTs\n",
        "appear to have reasonably fair outputs? To address this question, we will next\n",
        "train an unconstrained DNN and GBT. In fact, we will observe that the DNN and\n",
        "GBT both easily violate monotonicity in LSAT score and undergraduate GPA."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uo1ruWXcvUqb"
      },
      "source": [
        "### Train an unconstrained Deep Neural Network (DNN) model\n",
        "\n",
        "The architecture was previously optimized to achieve high validation accuracy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3pplraob0Od-"
      },
      "outputs": [],
      "source": [
        "keras.utils.set_random_seed(42)\n",
        "inputs = [\n",
        "    keras.Input(shape=(1,), dtype=tf.float32),\n",
        "    keras.Input(shape=(1), dtype=tf.float32),\n",
        "]\n",
        "inputs_flat = keras.layers.Concatenate()(inputs)\n",
        "dense_layers = keras.Sequential(\n",
        "    [\n",
        "        keras.layers.Dense(64, activation='relu'),\n",
        "        keras.layers.Dense(32, activation='relu'),\n",
        "        keras.layers.Dense(1, activation=None),\n",
        "    ],\n",
        "    name='dense_layers',\n",
        ")\n",
        "dnn_model = keras.Model(inputs=inputs, outputs=dense_layers(inputs_flat))\n",
        "dnn_model.compile(\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.BinaryAccuracy(name='accuracy')],\n",
        "    optimizer=keras.optimizers.Adam(LEARNING_RATES),\n",
        ")\n",
        "dnn_model.fit(datasets['law_train'], epochs=NUM_EPOCHS, verbose=0)\n",
        "\n",
        "train_acc = dnn_model.evaluate(datasets['law_train'])[1]\n",
        "val_acc = dnn_model.evaluate(datasets['law_val'])[1]\n",
        "test_acc = dnn_model.evaluate(datasets['law_test'])[1]\n",
        "print(\n",
        "    'accuracies for train: %f, val: %f, test: %f'\n",
        "    % (train_acc, val_acc, test_acc)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LwPQqLt-E7R4"
      },
      "outputs": [],
      "source": [
        "plot_model_contour(dnn_model, from_logits=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OOAKK0_3vWir"
      },
      "source": [
        "### Train an unconstrained Gradient Boosted Trees (GBT) model\n",
        "\n",
        "The tree structure was previously optimized to achieve high validation accuracy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6UrCJHqhgd3o"
      },
      "outputs": [],
      "source": [
        "tree_model = tfdf.keras.GradientBoostedTreesModel(\n",
        "    exclude_non_specified_features=False,\n",
        "    num_threads=1,\n",
        "    num_trees=20,\n",
        "    max_depth=4,\n",
        "    growing_strategy='BEST_FIRST_GLOBAL',\n",
        "    random_seed=42,\n",
        "    temp_directory=tempfile.mkdtemp(),\n",
        ")\n",
        "tree_model.compile(metrics=[keras.metrics.BinaryAccuracy(name='accuracy')])\n",
        "tree_model.fit(\n",
        "    datasets['law_train'], validation_data=datasets['law_val'], verbose=0\n",
        ")\n",
        "\n",
        "tree_train_acc = tree_model.evaluate(datasets['law_train'], verbose=0)[1]\n",
        "tree_val_acc = tree_model.evaluate(datasets['law_val'], verbose=0)[1]\n",
        "tree_test_acc = tree_model.evaluate(datasets['law_test'], verbose=0)[1]\n",
        "print(\n",
        "    'accuracies for GBT: train: %f, val: %f, test: %f'\n",
        "    % (tree_train_acc, tree_val_acc, tree_test_acc)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AZFyfQT1E_nR"
      },
      "outputs": [],
      "source": [
        "plot_model_contour(tree_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uX2qiMlrY8aO"
      },
      "source": [
        "# Case study #2: Credit Default\n",
        "\n",
        "The second case study that we will consider in this tutorial is predicting an\n",
        "i
SYMBOL INDEX (628 symbols across 42 files)

FILE: docs/build_docs.py
  function local_definitions_filter (line 55) | def local_definitions_filter(path, parent, children):
  function main (line 62) | def main(_):

FILE: examples/keras_functional_uci_heart.py
  function main (line 86) | def main(_):

FILE: examples/keras_sequential_uci_heart.py
  function main (line 80) | def main(_):

FILE: tensorflow_lattice/python/aggregation_layer.py
  class Aggregation (line 36) | class Aggregation(keras.layers.Layer):
    method __init__ (line 61) | def __init__(self, model, **kwargs):
    method call (line 78) | def call(self, x):
    method get_config (line 82) | def get_config(self):
    method from_config (line 91) | def from_config(cls, config, custom_objects=None):

FILE: tensorflow_lattice/python/aggregation_test.py
  class AggregationTest (line 40) | class AggregationTest(tf.test.TestCase):
    method testAggregationLayer (line 42) | def testAggregationLayer(self):

FILE: tensorflow_lattice/python/categorical_calibration_layer.py
  class CategoricalCalibration (line 40) | class CategoricalCalibration(keras.layers.Layer):
    method __init__ (line 97) | def __init__(self,
    method build (line 168) | def build(self, input_shape):
    method call (line 209) | def call(self, inputs):
    method compute_output_shape (line 240) | def compute_output_shape(self, input_shape):
    method get_config (line 248) | def get_config(self):
    method assert_constraints (line 268) | def assert_constraints(self, eps=1e-6):
  class CategoricalCalibrationConstraints (line 289) | class CategoricalCalibrationConstraints(keras.constraints.Constraint):
    method __init__ (line 302) | def __init__(self, output_min=None, output_max=None, monotonicities=No...
    method __call__ (line 318) | def __call__(self, w):
    method get_config (line 326) | def get_config(self):

FILE: tensorflow_lattice/python/categorical_calibration_lib.py
  function project (line 24) | def project(weights, output_min, output_max, monotonicities):
  function assert_constraints (line 64) | def assert_constraints(weights,
  function verify_hyperparameters (line 127) | def verify_hyperparameters(num_buckets=None,

FILE: tensorflow_lattice/python/categorical_calibration_test.py
  class CategoricalCalibrationLayerTest (line 39) | class CategoricalCalibrationLayerTest(parameterized.TestCase, tf.test.Te...
    method setUp (line 41) | def setUp(self):
    method _ResetAllBackends (line 48) | def _ResetAllBackends(self):
    method _ScatterXUniformly (line 52) | def _ScatterXUniformly(self, units, num_points, num_buckets,
    method _SetDefaults (line 73) | def _SetDefaults(self, config):
    method _TrainModel (line 87) | def _TrainModel(self, config):
    method testUnconstrainedNoMissingValue (line 176) | def testUnconstrainedNoMissingValue(self, y_function):
    method testUnconstrainedWithMissingValue (line 201) | def testUnconstrainedWithMissingValue(self, y_function):
    method testConstraints (line 237) | def testConstraints(self, output_min, output_max, monotonicities,
    method testCircularMonotonicites (line 275) | def testCircularMonotonicites(self):
    method testRegularizers (line 301) | def testRegularizers(self, regularizer):
    method testOutputShape (line 321) | def testOutputShape(self):

FILE: tensorflow_lattice/python/cdf_layer.py
  class CDF (line 35) | class CDF(keras.layers.Layer):
    method __init__ (line 84) | def __init__(self,
    method build (line 158) | def build(self, input_shape):
    method call (line 204) | def call(self, inputs):
    method get_config (line 245) | def get_config(self):
  function create_kernel_initializer (line 272) | def create_kernel_initializer(kernel_initializer_id):

FILE: tensorflow_lattice/python/cdf_test.py
  class CdfLayerTest (line 33) | class CdfLayerTest(parameterized.TestCase, tf.test.TestCase):
    method setUp (line 35) | def setUp(self):
    method _ResetAllBackends (line 42) | def _ResetAllBackends(self):
    method _SetDefaults (line 46) | def _SetDefaults(self, config):
    method _ScatterXUniformly (line 59) | def _ScatterXUniformly(self, num_points, input_dims):
    method _ScatterXUniformlyExtendedRange (line 70) | def _ScatterXUniformlyExtendedRange(self, num_points, input_dims):
    method _TwoDMeshGrid (line 81) | def _TwoDMeshGrid(self, num_points, input_dims):
    method _TwoDMeshGridExtendedRange (line 89) | def _TwoDMeshGridExtendedRange(self, num_points, input_dims):
    method _Sin (line 97) | def _Sin(self, x):
    method _SinPlusX (line 100) | def _SinPlusX(self, x):
    method _SinPlusXNd (line 103) | def _SinPlusXNd(self, x):
    method _SinOfSum (line 106) | def _SinOfSum(self, x):
    method _Square (line 109) | def _Square(self, x):
    method _ScaledSum (line 112) | def _ScaledSum(self, x):
    method _GetTrainingInputsAndLabels (line 118) | def _GetTrainingInputsAndLabels(self, config):
    method _TrainModel (line 142) | def _TrainModel(self, config):
    method test1Dim (line 199) | def test1Dim(self, activation, reduction, input_scaling_type, expected...
    method test2Dim (line 231) | def test2Dim(self, activation, reduction, input_scaling_type, expected...
    method test5DimScaledSum (line 263) | def test5DimScaledSum(self, activation, reduction, input_scaling_type,
    method test5DimSinOfSum (line 296) | def test5DimSinOfSum(self, activation, reduction, input_scaling_type,
    method test1DimInputOutOfBounds (line 329) | def test1DimInputOutOfBounds(self, activation, reduction, input_scalin...
    method test2DimInputOutOfBounds (line 362) | def test2DimInputOutOfBounds(self, activation, reduction, input_scalin...
    method testMultiUnitOutputSparsity (line 395) | def testMultiUnitOutputSparsity(self, input_dims, units, activation,
    method testInputScalingInit (line 434) | def testInputScalingInit(self, activation, reduction, input_scaling_init,
    method testGraphSize (line 480) | def testGraphSize(self, input_dims, num_keypoints, units, activation,

FILE: tensorflow_lattice/python/conditional_cdf.py
  function _verify_cdf_params (line 40) | def _verify_cdf_params(
  function cdf_fn (line 119) | def cdf_fn(

FILE: tensorflow_lattice/python/conditional_cdf_test.py
  class CdfFnTest (line 23) | class CdfFnTest(parameterized.TestCase, tf.test.TestCase):
    method assertAllClose (line 25) | def assertAllClose(self, x, y):
    method test_compute_sigmoid (line 122) | def test_compute_sigmoid(
    method test_compute_relu6 (line 270) | def test_compute_relu6(
    method test_scaling_exp_transformation (line 309) | def test_scaling_exp_transformation(
    method test_gradient (line 476) | def test_gradient(
    method test_raise (line 701) | def test_raise(

FILE: tensorflow_lattice/python/conditional_pwl_calibration.py
  function _front_pad (line 62) | def _front_pad(x: tf.Tensor, constant_values: float) -> tf.Tensor:
  function default_keypoint_output_parameters (line 66) | def default_keypoint_output_parameters(
  function default_keypoint_input_parameters (line 114) | def default_keypoint_input_parameters(
  function _verify_pwl_calibration (line 150) | def _verify_pwl_calibration(
  function _compute_interpolation_weights (line 249) | def _compute_interpolation_weights(inputs, keypoints, lengths):
  function pwl_calibration_fn (line 270) | def pwl_calibration_fn(

FILE: tensorflow_lattice/python/conditional_pwl_calibration_test.py
  class PwlCalibrationFnTest (line 24) | class PwlCalibrationFnTest(tf.test.TestCase):
    method assertAllClose (line 26) | def assertAllClose(self, x, y):
    method assertAllGreaterEqual (line 29) | def assertAllGreaterEqual(self, a, comparison_target):
    method assertAllLessEqual (line 32) | def assertAllLessEqual(self, a, comparison_target):
    method assertAllEqual (line 35) | def assertAllEqual(self, a, comparison_target):
    method setUp (line 40) | def setUp(self):
    method test_suite_none_monotonic (line 87) | def test_suite_none_monotonic(self):
    method test_suite_increasing_monotonic (line 240) | def test_suite_increasing_monotonic(self):
    method test_gradient_step (line 399) | def test_gradient_step(self):
    method test_suite_raises (line 440) | def test_suite_raises(self):

FILE: tensorflow_lattice/python/configs.py
  class _Config (line 80) | class _Config(object):
    method __init__ (line 83) | def __init__(self, kwargs):
    method __repr__ (line 90) | def __repr__(self):
    method get_config (line 93) | def get_config(self):
    method deserialize_nested_configs (line 125) | def deserialize_nested_configs(cls, config, custom_objects=None):
  class _HasFeatureConfigs (line 161) | class _HasFeatureConfigs(object):
    method feature_config_by_name (line 164) | def feature_config_by_name(self, feature_name):
  class _HasRegularizerConfigs (line 176) | class _HasRegularizerConfigs(object):
    method regularizer_config_by_name (line 179) | def regularizer_config_by_name(self, regularizer_name):
  class CalibratedLatticeEnsembleConfig (line 194) | class CalibratedLatticeEnsembleConfig(_Config, _HasFeatureConfigs,
    method __init__ (line 275) | def __init__(self,
    method from_config (line 392) | def from_config(cls, config, custom_objects=None):
  class CalibratedLatticeConfig (line 397) | class CalibratedLatticeConfig(_Config, _HasFeatureConfigs,
    method __init__ (line 422) | def __init__(self,
    method from_config (line 504) | def from_config(cls, config, custom_objects=None):
  class CalibratedLinearConfig (line 509) | class CalibratedLinearConfig(_Config, _HasFeatureConfigs,
    method __init__ (line 536) | def __init__(self,
    method from_config (line 580) | def from_config(cls, config, custom_objects=None):
  class AggregateFunctionConfig (line 586) | class AggregateFunctionConfig(_Config, _HasFeatureConfigs,
    method __init__ (line 609) | def __init__(self,
    method from_config (line 685) | def from_config(cls, config, custom_objects=None):
  class FeatureConfig (line 690) | class FeatureConfig(_Config, _HasRegularizerConfigs):
    method __init__ (line 747) | def __init__(self,
    method from_config (line 838) | def from_config(cls, config, custom_objects=None):
  class RegularizerConfig (line 843) | class RegularizerConfig(_Config):
    method __init__ (line 922) | def __init__(self, name, l1=0.0, l2=0.0):
    method from_config (line 933) | def from_config(cls, config, custom_objects=None):
  class TrustConfig (line 938) | class TrustConfig(_Config):
    method __init__ (line 1030) | def __init__(self,
    method from_config (line 1046) | def from_config(cls, config, custom_objects=None):
  class DominanceConfig (line 1051) | class DominanceConfig(_Config):
    method __init__ (line 1085) | def __init__(self, feature_name, dominance_type='monotonic'):
    method from_config (line 1097) | def from_config(cls, config, custom_objects=None):
  class _TypeDict (line 1102) | class _TypeDict(collections.defaultdict):
    method __init__ (line 1105) | def __init__(self, hparams):
    method __contains__ (line 1110) | def __contains__(self, _):
  function apply_updates (line 1114) | def apply_updates(model_config, updates):
  function _apply_update (line 1155) | def _apply_update(node, k, v):

FILE: tensorflow_lattice/python/configs_test.py
  class ConfigsTest (line 60) | class ConfigsTest(tf.test.TestCase):
    method test_from_config (line 62) | def test_from_config(self):
    method test_updates (line 150) | def test_updates(self):

FILE: tensorflow_lattice/python/internal_utils.py
  function _topological_sort (line 28) | def _topological_sort(key_less_than_values):
  function _min_projection (line 65) | def _min_projection(weights, sorted_indices, key_less_than_values, step):
  function _max_projection (line 96) | def _max_projection(weights, sorted_indices, key_greater_than_values, st...
  function approximately_project_categorical_partial_monotonicities (line 127) | def approximately_project_categorical_partial_monotonicities(

FILE: tensorflow_lattice/python/internal_utils_test.py
  class InternalUtilsTest (line 26) | class InternalUtilsTest(parameterized.TestCase, tf.test.TestCase):
    method _ResetAllBackends (line 28) | def _ResetAllBackends(self):
    method testApproximatelyProjectCategoricalPartialMonotonicities (line 36) | def testApproximatelyProjectCategoricalPartialMonotonicities(

FILE: tensorflow_lattice/python/kronecker_factored_lattice_layer.py
  class KroneckerFactoredLattice (line 51) | class KroneckerFactoredLattice(keras.layers.Layer):
    method __init__ (line 121) | def __init__(self,
    method build (line 194) | def build(self, input_shape):
    method call (line 280) | def call(self, inputs):
    method compute_output_shape (line 292) | def compute_output_shape(self, input_shape):
    method get_config (line 302) | def get_config(self):
    method finalize_constraints (line 323) | def finalize_constraints(self):
    method assert_constraints (line 339) | def assert_constraints(self, eps=1e-6):
  function create_kernel_initializer (line 362) | def create_kernel_initializer(kernel_initializer_id,
  function create_scale_initializer (line 418) | def create_scale_initializer(scale_initializer_id, output_min, output_max):
  class KFLRandomMonotonicInitializer (line 449) | class KFLRandomMonotonicInitializer(keras.initializers.Initializer):
    method __init__ (line 454) | def __init__(self, monotonicities, init_min=0.5, init_max=1.5, seed=No...
    method __call__ (line 469) | def __call__(self, shape, scale, dtype=None, **kwargs):
    method get_config (line 489) | def get_config(self):
  class ScaleInitializer (line 500) | class ScaleInitializer(keras.initializers.Initializer):
    method __init__ (line 512) | def __init__(self, output_min, output_max):
    method __call__ (line 522) | def __call__(self, shape, dtype=None, **kwargs):
    method get_config (line 538) | def get_config(self):
  class BiasInitializer (line 547) | class BiasInitializer(keras.initializers.Initializer):
    method __init__ (line 558) | def __init__(self, output_min, output_max):
    method __call__ (line 568) | def __call__(self, shape, dtype=None, **kwargs):
    method get_config (line 583) | def get_config(self):
  class KroneckerFactoredLatticeConstraints (line 592) | class KroneckerFactoredLatticeConstraints(keras.constraints.Constraint):
    method __init__ (line 604) | def __init__(self,
    method __call__ (line 631) | def __call__(self, w):
    method get_config (line 651) | def get_config(self):
  class ScaleConstraints (line 662) | class ScaleConstraints(keras.constraints.Constraint):
    method __init__ (line 676) | def __init__(self, output_min=None, output_max=None):
    method __call__ (line 688) | def __call__(self, scale):
    method get_config (line 703) | def get_config(self):

FILE: tensorflow_lattice/python/kronecker_factored_lattice_lib.py
  function custom_reduce_prod (line 25) | def custom_reduce_prod(t, axis):
  function evaluate_with_hypercube_interpolation (line 73) | def evaluate_with_hypercube_interpolation(inputs, scale, bias, kernel, u...
  function default_init_params (line 152) | def default_init_params(output_min, output_max):
  function kfl_random_monotonic_initializer (line 165) | def kfl_random_monotonic_initializer(shape,
  function scale_initializer (line 224) | def scale_initializer(units, num_terms, output_min, output_max):
  function bias_initializer (line 256) | def bias_initializer(units, output_min, output_max, dtype=tf.float32):
  function _approximately_project_monotonicity (line 287) | def _approximately_project_monotonicity(weights, units, scale, monotonic...
  function _approximately_project_bounds (line 346) | def _approximately_project_bounds(weights, units, output_min, output_max):
  function finalize_weight_constraints (line 394) | def finalize_weight_constraints(weights, units, scale, monotonicities,
  function finalize_scale_constraints (line 447) | def finalize_scale_constraints(scale, output_min, output_max):
  function verify_hyperparameters (line 472) | def verify_hyperparameters(lattice_sizes=None,
  function _assert_monotonicity_constraints (line 552) | def _assert_monotonicity_constraints(weights, units, scale, monotonicities,
  function _assert_bound_constraints (line 600) | def _assert_bound_constraints(weights, units, scale, output_min, output_...
  function assert_constraints (line 697) | def assert_constraints(weights,

FILE: tensorflow_lattice/python/kronecker_factored_lattice_test.py
  class KroneckerFactoredLatticeTest (line 38) | class KroneckerFactoredLatticeTest(parameterized.TestCase, tf.test.TestC...
    method setUp (line 40) | def setUp(self):
    method _ResetAllBackends (line 49) | def _ResetAllBackends(self):
    method _ScatterXUniformly (line 53) | def _ScatterXUniformly(self, num_points, lattice_sizes, input_dims):
    method _ScatterXUniformlyExtendedRange (line 65) | def _ScatterXUniformlyExtendedRange(self, num_points, lattice_sizes,
    method _SameValueForAllDims (line 79) | def _SameValueForAllDims(self, num_points, lattice_sizes, input_dims):
    method _TwoDMeshGrid (line 90) | def _TwoDMeshGrid(self, num_points, lattice_sizes, input_dims):
    method _TwoDMeshGridExtendedRange (line 102) | def _TwoDMeshGridExtendedRange(self, num_points, lattice_sizes, input_...
    method _Sin (line 114) | def _Sin(self, x):
    method _SinPlusX (line 117) | def _SinPlusX(self, x):
    method _SinPlusLargeX (line 120) | def _SinPlusLargeX(self, x):
    method _SinPlusXNd (line 123) | def _SinPlusXNd(self, x):
    method _SinOfSum (line 126) | def _SinOfSum(self, x):
    method _Max (line 129) | def _Max(self, x):
    method _ScaledSum (line 132) | def _ScaledSum(self, x):
    method _GetNonMonotonicInitializer (line 138) | def _GetNonMonotonicInitializer(self, weights):
    method _GetTrainingInputsAndLabels (line 156) | def _GetTrainingInputsAndLabels(self, config):
    method _SetDefaults (line 181) | def _SetDefaults(self, config):
    method _TestEnsemble (line 194) | def _TestEnsemble(self, config):
    method _TrainModel (line 212) | def _TrainModel(self, config):
    method testMonotonicityOneD (line 276) | def testMonotonicityOneD(self):
    method testMonotonicityTwoD (line 338) | def testMonotonicityTwoD(self):
    method testMonotonicity5d (line 417) | def testMonotonicity5d(self):
    method testMonotonicityEquivalence (line 481) | def testMonotonicityEquivalence(self, monotonicities):
    method testMonotonicity10dAlmostMonotone (line 500) | def testMonotonicity10dAlmostMonotone(self):
    method testMonotonicity10dSinOfSum (line 530) | def testMonotonicity10dSinOfSum(self):
    method testInitializerType (line 567) | def testInitializerType(self, initializer, expected_loss):
    method testAssertMonotonicity (line 588) | def testAssertMonotonicity(self):
    method testAssertBounds (line 663) | def testAssertBounds(self, output_min, output_max, kernel_initializer,
    method testOutputBounds (line 697) | def testOutputBounds(self, units, input_dims, output_min, output_max,
    method testConstraints (line 736) | def testConstraints(self, lattice_sizes, units, dims, num_terms, outpu...
    method testInputOutOfBounds (line 810) | def testInputOutOfBounds(self):
    method testHighDimensionsStressTest (line 845) | def testHighDimensionsStressTest(self):
    method testGraphSize (line 877) | def testGraphSize(self, lattice_sizes, input_dims, num_terms,
    method testCreateKernelInitializer (line 900) | def testCreateKernelInitializer(self, kernel_initializer_id, expected_...
    method testSavingLoadingScale (line 915) | def testSavingLoadingScale(self):
    method testOutputShapeForDifferentInputTypes (line 965) | def testOutputShapeForDifferentInputTypes(self, batch_size, dims, units):

FILE: tensorflow_lattice/python/lattice_layer.py
  class Lattice (line 41) | class Lattice(keras.layers.Layer):
    method __init__ (line 162) | def __init__(self,
    method build (line 370) | def build(self, input_shape):
    method call (line 442) | def call(self, inputs):
    method compute_output_shape (line 466) | def compute_output_shape(self, input_shape):
    method get_config (line 476) | def get_config(self):
    method finalize_constraints (line 505) | def finalize_constraints(self):
    method assert_constraints (line 519) | def assert_constraints(self, eps=1e-6):
  function create_kernel_initializer (line 547) | def create_kernel_initializer(kernel_initializer_id,
  class LinearInitializer (line 653) | class LinearInitializer(keras.initializers.Initializer):
    method __init__ (line 674) | def __init__(self,
    method __call__ (line 708) | def __call__(self, shape, dtype=None, partition_info=None):
    method get_config (line 728) | def get_config(self):
  class RandomMonotonicInitializer (line 740) | class RandomMonotonicInitializer(keras.initializers.Initializer):
    method __init__ (line 754) | def __init__(self, lattice_sizes, output_min, output_max, unimodalitie...
    method __call__ (line 778) | def __call__(self, shape, dtype=None, partition_info=None):
    method get_config (line 794) | def get_config(self):
  class LatticeConstraints (line 805) | class LatticeConstraints(keras.constraints.Constraint):
    method __init__ (line 817) | def __init__(self,
    method __call__ (line 883) | def __call__(self, w):
    method get_config (line 919) | def get_config(self):
  class TorsionRegularizer (line 938) | class TorsionRegularizer(keras.regularizers.Regularizer):
    method __init__ (line 967) | def __init__(self, lattice_sizes, l1=0.0, l2=0.0):
    method __call__ (line 985) | def __call__(self, x):
    method get_config (line 992) | def get_config(self):
  class LaplacianRegularizer (line 1001) | class LaplacianRegularizer(keras.regularizers.Regularizer):
    method __init__ (line 1036) | def __init__(self, lattice_sizes, l1=0.0, l2=0.0):
    method __call__ (line 1061) | def __call__(self, x):
    method get_config (line 1068) | def get_config(self):

FILE: tensorflow_lattice/python/lattice_lib.py
  function evaluate_with_simplex_interpolation (line 32) | def evaluate_with_simplex_interpolation(inputs, kernel, units, lattice_s...
  function evaluate_with_hypercube_interpolation (line 150) | def evaluate_with_hypercube_interpolation(inputs, kernel, units, lattice...
  function compute_interpolation_weights (line 193) | def compute_interpolation_weights(inputs, lattice_sizes, clip_inputs=True):
  function batch_outer_operation (line 275) | def batch_outer_operation(list_of_tensors, operation="auto"):
  function _clip_onto_lattice_range (line 337) | def _clip_onto_lattice_range(inputs, lattice_sizes):
  function _bucketize_consequtive_equal_dims (line 372) | def _bucketize_consequtive_equal_dims(inputs, lattice_sizes):
  function default_init_params (line 416) | def default_init_params(output_min, output_max):
  function linear_initializer (line 441) | def linear_initializer(lattice_sizes,
  function _linspace (line 524) | def _linspace(start, stop, num):
  function random_monotonic_initializer (line 531) | def random_monotonic_initializer(lattice_sizes,
  function _approximately_project_monotonicity (line 620) | def _approximately_project_monotonicity(weights, lattice_sizes, monotoni...
  function _approximately_project_edgeworth (line 710) | def _approximately_project_edgeworth(weights, lattice_sizes, units,
  function _approximately_project_trapezoid (line 812) | def _approximately_project_trapezoid(weights, lattice_sizes, units,
  function _trapezoid_violation_update (line 949) | def _trapezoid_violation_update(differences, units, any_edgeworth,
  function _approximately_project_bounds (line 986) | def _approximately_project_bounds(weights, units, output_min, output_max):
  function finalize_constraints (line 1039) | def finalize_constraints(weights,
  function _project_partial_monotonicity (line 1109) | def _project_partial_monotonicity(weights, lattice_sizes, monotonicities,
  function _project_partial_edgeworth (line 1218) | def _project_partial_edgeworth(weights, lattice_sizes, edgeworth_trust,
  function _project_partial_trapezoid (line 1317) | def _project_partial_trapezoid(weights, lattice_sizes, trapezoid_trust,
  function _project_partial_monotonic_dominance (line 1419) | def _project_partial_monotonic_dominance(weights, lattice_sizes,
  function _project_partial_range_dominance (line 1510) | def _project_partial_range_dominance(weights, lattice_sizes, range_domin...
  function _project_partial_joint_monotonicity (line 1603) | def _project_partial_joint_monotonicity(weights, lattice_sizes,
  function _project_partial_joint_unimodality (line 1695) | def _project_partial_joint_unimodality(weights, lattice_sizes,
  function _project_onto_hyperplane (line 1766) | def _project_onto_hyperplane(weights, joint_unimodalities, hyperplane,
  function project_by_dykstra (line 1830) | def project_by_dykstra(weights,
  function laplacian_regularizer (line 2074) | def laplacian_regularizer(weights, lattice_sizes, l1=0.0, l2=0.0):
  function torsion_regularizer (line 2157) | def torsion_regularizer(weights, lattice_sizes, l1=0.0, l2=0.0):
  function _verify_dominances_hyperparameters (line 2245) | def _verify_dominances_hyperparameters(dominances, dominance_type,
  function verify_hyperparameters (line 2291) | def verify_hyperparameters(lattice_sizes,
  function assert_constraints (line 2524) | def assert_constraints(weights,
  function _unstack_nested_lists (line 2726) | def _unstack_nested_lists(tensor_or_list, axis):
  function _unstack_nd (line 2734) | def _unstack_nd(tensor, dims):
  function _stack_nested_lists (line 2749) | def _stack_nested_lists(tensor_or_list, axis):
  function _stack_nd (line 2757) | def _stack_nd(tensor, dims):
  function _get_element (line 2772) | def _get_element(lists, indices):
  function _set_element (line 2780) | def _set_element(lists, indices, value):
  function _reverse_second_list_dimension (line 2788) | def _reverse_second_list_dimension(layers):
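
The core of `lattice_lib` is interpolation over a grid of learned vertex values: `compute_interpolation_weights` produces one weight per lattice vertex and the output is the weighted sum with the kernel. The batched TF implementation above is involved; as a single-example NumPy sketch (assuming `clip_inputs=True` and hypercube rather than simplex interpolation), the idea is:

```python
import numpy as np

def hypercube_interpolation_weights(x, lattice_sizes):
    """One multilinear-interpolation weight per lattice vertex.

    NumPy sketch of the idea behind lattice_lib.compute_interpolation_weights
    for a single example; the real function is batched and built from TF ops.
    """
    per_dim = []
    for xi, size in zip(x, lattice_sizes):
        xi = np.clip(xi, 0.0, size - 1.0)  # clip_inputs=True behaviour
        vertices = np.arange(size)
        # A vertex at integer position v contributes max(0, 1 - |x - v|).
        per_dim.append(np.maximum(0.0, 1.0 - np.abs(xi - vertices)))
    # Outer product across dimensions -> flat vector of vertex weights.
    weights = per_dim[0]
    for w in per_dim[1:]:
        weights = np.outer(weights, w).ravel()
    return weights

# 2x2 lattice whose kernel stores f(v0, v1) = v0 + 2*v1 at the corners,
# flattened in (v0, v1) order: (0,0), (0,1), (1,0), (1,1).
kernel = np.array([0.0, 2.0, 1.0, 3.0])
w = hypercube_interpolation_weights([0.5, 0.25], [2, 2])
print(np.dot(w, kernel))  # multilinear interpolation of f -> 1.0
```

Because the weights form a partition of unity over the enclosing cell, the lattice output at any interior point is a convex combination of the surrounding vertex parameters, which is what makes vertex-level constraints (monotonicity, bounds) carry over to the whole function.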

FILE: tensorflow_lattice/python/lattice_test.py
  class LatticeTest (line 36) | class LatticeTest(parameterized.TestCase, tf.test.TestCase):
    method setUp (line 38) | def setUp(self):
    method _ResetAllBackends (line 46) | def _ResetAllBackends(self):
    method _ScatterXUniformly (line 50) | def _ScatterXUniformly(self, num_points, lattice_sizes):
    method _ScatterXUniformlyExtendedRange (line 64) | def _ScatterXUniformlyExtendedRange(self, num_points, lattice_sizes):
    method _SameValueForAllDims (line 78) | def _SameValueForAllDims(self, num_points, lattice_sizes):
    method _TwoDMeshGrid (line 93) | def _TwoDMeshGrid(self, num_points, lattice_sizes):
    method _TwoDMeshGridExtendedRange (line 105) | def _TwoDMeshGridExtendedRange(self, num_points, lattice_sizes):
    method _Sin (line 117) | def _Sin(self, x):
    method _SinPlusX (line 120) | def _SinPlusX(self, x):
    method _SinPlusLargeX (line 123) | def _SinPlusLargeX(self, x):
    method _SinPlusXNd (line 126) | def _SinPlusXNd(self, x):
    method _SinOfSum (line 129) | def _SinOfSum(self, x):
    method _Square (line 132) | def _Square(self, x):
    method _Max (line 135) | def _Max(self, x):
    method _WeightedSum (line 138) | def _WeightedSum(self, x):
    method _MixedSignWeightedSum (line 144) | def _MixedSignWeightedSum(self, x):
    method _PseudoLinear (line 151) | def _PseudoLinear(self, x):
    method _ScaledSum (line 160) | def _ScaledSum(self, x):
    method _GetMultiOutputInitializer (line 166) | def _GetMultiOutputInitializer(self, weights):
    method _GetTrainingInputsAndLabels (line 176) | def _GetTrainingInputsAndLabels(self, config):
    method _SetDefaults (line 200) | def _SetDefaults(self, config):
    method _TestEnsemble (line 223) | def _TestEnsemble(self, config):
    method _TrainModel (line 238) | def _TrainModel(self, config):
    method testMonotonicityOneD (line 308) | def testMonotonicityOneD(self):
    method testMonotonicityTwoD (line 360) | def testMonotonicityTwoD(self):
    method testMonotonicity5d (line 441) | def testMonotonicity5d(self):
    method testMonotonicityEquivalence (line 494) | def testMonotonicityEquivalence(self, monotonicities):
    method testMonotonicity10dAlmostMonotone (line 512) | def testMonotonicity10dAlmostMonotone(self):
    method testMonotonicity10dSinOfSum (line 541) | def testMonotonicity10dSinOfSum(self):
    method testSimpleTrustTwoD (line 573) | def testSimpleTrustTwoD(self, edgeworth_trusts, trapezoid_trusts,
    method testDenseTrustTwoD (line 602) | def testDenseTrustTwoD(self, edgeworth_trusts, trapezoid_trusts,
    method testSimpleTrust4D (line 632) | def testSimpleTrust4D(self, edgeworth_trusts, trapezoid_trusts,
    method testMultiDenseTrust4D (line 661) | def testMultiDenseTrust4D(self, edgeworth_trusts, trapezoid_trusts,
    method testEdgeworthTrustEquivalence (line 691) | def testEdgeworthTrustEquivalence(self, edgeworth_trusts):
    method testSimpleMonotonicDominance2D (line 717) | def testSimpleMonotonicDominance2D(self, monotonic_dominances, expecte...
    method testDenseMonotonicDominance2D (line 744) | def testDenseMonotonicDominance2D(self, monotonic_dominances, expected...
    method testDenseMonotonicDominance5D (line 771) | def testDenseMonotonicDominance5D(self, monotonic_dominances, expected...
    method testSimpleRangeDominance2D (line 799) | def testSimpleRangeDominance2D(self, range_dominances, expected_loss):
    method testDenseRangeDominance2D (line 826) | def testDenseRangeDominance2D(self, range_dominances, expected_loss, e...
    method testDenseRangeDominance5D (line 853) | def testDenseRangeDominance5D(self, range_dominances, expected_loss):
    method testSimpleJointMonotonicity2D (line 881) | def testSimpleJointMonotonicity2D(self, joint_monotonicities, expected...
    method testJointUnimodality1D (line 908) | def testJointUnimodality1D(self, joint_unimodalities, expected_loss):
    method testJointUnimodality2DSinOfSum (line 936) | def testJointUnimodality2DSinOfSum(self):
    method testJointUnimodality2DWshaped (line 966) | def testJointUnimodality2DWshaped(self, joint_unimodalities, expected_...
    method testJointUnimodality2OutOf4D (line 1008) | def testJointUnimodality2OutOf4D(self, joint_unimodalities):
    method testJointUnimodality3D (line 1062) | def testJointUnimodality3D(self):
    method testDenseJointMonotonicity2D (line 1088) | def testDenseJointMonotonicity2D(self, joint_monotonicities, expected_...
    method testDenseJointMonotonicity5D (line 1113) | def testDenseJointMonotonicity5D(self, joint_monotonicities, expected_...
    method testInitializerType (line 1144) | def testInitializerType(self, initializer, expected_loss):
    method _MergeDicts (line 1163) | def _MergeDicts(self, x, y):
    method testLinearMonotonicInitializer (line 1168) | def testLinearMonotonicInitializer(self):
    method testUnimodalInitializer (line 1248) | def testUnimodalInitializer(self):
    method testRandomMonotonicInitializer (line 1280) | def testRandomMonotonicInitializer(self):
    method testAssertMonotonicity (line 1326) | def testAssertMonotonicity(self):
    method testBounds (line 1353) | def testBounds(self):
    method testInputOutOfBounds (line 1403) | def testInputOutOfBounds(self):
    method testRegularizers2d (line 1447) | def testRegularizers2d(self, regularizer, pure_reg_loss, training_loss):
    method testRegularizersLargeLattice (line 1482) | def testRegularizersLargeLattice(self, regularizer, expected_loss):
    method testHighDimensionsStressTest (line 1498) | def testHighDimensionsStressTest(self):
    method testUnimodalityOneD (line 1539) | def testUnimodalityOneD(self, monotonicities, unimodalities, expected_...
    method testUnimodalityTwoD (line 1581) | def testUnimodalityTwoD(self, monotonicities, unimodalities, expected_...
    method testUnconstrained (line 1612) | def testUnconstrained(self):
    method testEqaulySizedDimsOptimization (line 1885) | def testEqaulySizedDimsOptimization(self, lattice_sizes, expected_loss):
    method testGraphSize (line 1907) | def testGraphSize(self, lattice_sizes, expected_graph_size):
    method testCreateKernelInitializer (line 1968) | def testCreateKernelInitializer(self, kernel_initializer_id, lattice_s...
    method testSimplexInterpolation (line 2060) | def testSimplexInterpolation(self, lattice_sizes, kernel, inputs,
    method testFinalizeConstraints (line 2117) | def testFinalizeConstraints(self, lattice_sizes, kernel, edgeworth_tru...

FILE: tensorflow_lattice/python/linear_layer.py
  class Linear (line 42) | class Linear(keras.layers.Layer):
    method __init__ (line 92) | def __init__(self,
    method build (line 197) | def build(self, input_shape):
    method call (line 273) | def call(self, inputs):
    method compute_output_shape (line 288) | def compute_output_shape(self, input_shape):
    method get_config (line 293) | def get_config(self):
    method assert_constraints (line 327) | def assert_constraints(self, eps=1e-4):
  class LinearConstraints (line 350) | class LinearConstraints(keras.constraints.Constraint):
    method __init__ (line 378) | def __init__(self, monotonicities, monotonic_dominances=None,
    method __call__ (line 403) | def __call__(self, w):
    method get_config (line 425) | def get_config(self):

FILE: tensorflow_lattice/python/linear_lib.py
  function project (line 28) | def project(weights,
  function assert_constraints (line 114) | def assert_constraints(weights,
  function verify_hyperparameters (line 211) | def verify_hyperparameters(num_input_dims=None,
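
`linear_lib.project` maps unconstrained linear coefficients back onto the feasible set after each gradient step. A minimal sketch of the monotonicity part (the real function also handles dominance constraints and an optional weight-norm constraint, which this omits):

```python
import numpy as np

def project_monotonic(weights, monotonicities):
    """Clip linear coefficients to respect per-input monotonicity.

    Simplified sketch of one piece of linear_lib.project: a weight on an
    increasing (+1) input may not be negative, and a weight on a
    decreasing (-1) input may not be positive; 0 means unconstrained.
    """
    w = np.asarray(weights, dtype=float).copy()
    for i, m in enumerate(monotonicities):
        if m == 1:
            w[i] = max(w[i], 0.0)
        elif m == -1:
            w[i] = min(w[i], 0.0)
    return w

# The negative weight on an increasing input is clipped to zero.
print(project_monotonic([0.5, -0.2, 0.3], [1, 1, 0]))
```

Clipping is the exact Euclidean projection here because the constraint set is an axis-aligned box in weight space.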

FILE: tensorflow_lattice/python/linear_test.py
  class LinearTest (line 43) | class LinearTest(parameterized.TestCase, tf.test.TestCase):
    method setUp (line 46) | def setUp(self):
    method _ResetAllBackends (line 50) | def _ResetAllBackends(self):
    method _ScaterXUniformly (line 54) | def _ScaterXUniformly(self, num_points, num_dims, input_min, input_max):
    method _TwoDMeshGrid (line 68) | def _TwoDMeshGrid(self, num_points, num_dims, input_min, input_max):
    method _GenLinearFunction (line 80) | def _GenLinearFunction(self, weights, bias=0.0, noise=None):
    method _SinPlusXPlusD (line 96) | def _SinPlusXPlusD(self, x):
    method _SetDefaults (line 99) | def _SetDefaults(self, config):
    method _GetTrainingInputsAndLabels (line 116) | def _GetTrainingInputsAndLabels(self, config):
    method _TrainModel (line 142) | def _TrainModel(self, config):
    method _NegateAndTrain (line 214) | def _NegateAndTrain(self, config):
    method testOneDUnconstrained (line 235) | def testOneDUnconstrained(self, use_bias, expected_loss):
    method testTwoDUnconstrained (line 256) | def testTwoDUnconstrained(self, use_bias, expected_loss):
    method testInitializers (line 277) | def testInitializers(self):
    method testAssertConstraints (line 298) | def testAssertConstraints(self):
    method testOneDMonotonicities_MonotonicInput (line 334) | def testOneDMonotonicities_MonotonicInput(self, use_bias, expected_loss):
    method testOneDMonotonicities_AntiMonotonicInput (line 357) | def testOneDMonotonicities_AntiMonotonicInput(self, use_bias, expected...
    method testOneDNormalizationOrder (line 380) | def testOneDNormalizationOrder(self, norm_order, weight):
    method testOneDNormalizationOrderZeroWeights (line 401) | def testOneDNormalizationOrderZeroWeights(self):
    method testTwoDMonotonicity (line 433) | def testTwoDMonotonicity(self, expected_loss, monotonicities):
    method testTwoDNormalizationOrder (line 476) | def testTwoDNormalizationOrder(self, norm_order, weights, monotonicities,
    method testFiveDAllConstraints (line 507) | def testFiveDAllConstraints(self, weights, monotonicities, expected_lo...
    method testTwoDMonotonicDominance (line 533) | def testTwoDMonotonicDominance(self, expected_loss, dominances):
    method testTwoDRangeDominance (line 558) | def testTwoDRangeDominance(self, dominances, monotonicities, weights,
    method testRegularizers (line 587) | def testRegularizers(self, regularizer):

FILE: tensorflow_lattice/python/model_info.py
  class ModelGraph (line 28) | class ModelGraph(
  class InputFeatureNode (line 41) | class InputFeatureNode(
  class PWLCalibrationNode (line 53) | class PWLCalibrationNode(
  class CategoricalCalibrationNode (line 69) | class CategoricalCalibrationNode(
  class LinearNode (line 82) | class LinearNode(
  class LatticeNode (line 94) | class LatticeNode(
  class KroneckerFactoredLatticeNode (line 104) | class KroneckerFactoredLatticeNode(
  class MeanNode (line 119) | class MeanNode(collections.namedtuple('MeanNode', ['input_nodes'])):

FILE: tensorflow_lattice/python/parallel_combination_layer.py
  class ParallelCombination (line 39) | class ParallelCombination(keras.layers.Layer):
    method __init__ (line 79) | def __init__(self, calibration_layers=None, single_output=True, **kwar...
    method append (line 116) | def append(self, calibration_layer):
    method build (line 120) | def build(self, input_shape):
    method call (line 135) | def call(self, inputs):
    method compute_output_shape (line 155) | def compute_output_shape(self, input_shape):
    method get_config (line 161) | def get_config(self):

FILE: tensorflow_lattice/python/parallel_combination_test.py
  class ParallelCombinationTest (line 34) | class ParallelCombinationTest(parameterized.TestCase, tf.test.TestCase):
    method setUp (line 36) | def setUp(self):
    method testParallelCombinationSingleInput (line 41) | def testParallelCombinationSingleInput(self):
    method testParallelCombinationMultipleInputs (line 74) | def testParallelCombinationMultipleInputs(self):
    method testParallelCombinationClone (line 109) | def testParallelCombinationClone(self):

FILE: tensorflow_lattice/python/premade.py
  class CalibratedLatticeEnsemble (line 61) | class CalibratedLatticeEnsemble(keras.Model):
    method __init__ (line 86) | def __init__(self, model_config=None, dtype=tf.float32, **kwargs):
    method get_config (line 144) | def get_config(self):
    method from_config (line 153) | def from_config(cls, config, custom_objects=None):
  class CalibratedLattice (line 163) | class CalibratedLattice(keras.Model):
    method __init__ (line 187) | def __init__(self, model_config=None, dtype=tf.float32, **kwargs):
    method get_config (line 254) | def get_config(self):
    method from_config (line 263) | def from_config(cls, config, custom_objects=None):
  class CalibratedLinear (line 273) | class CalibratedLinear(keras.Model):
    method __init__ (line 297) | def __init__(self, model_config=None, dtype=tf.float32, **kwargs):
    method get_config (line 367) | def get_config(self):
    method from_config (line 376) | def from_config(cls, config, custom_objects=None):
  class AggregateFunction (line 388) | class AggregateFunction(keras.Model):
    method __init__ (line 410) | def __init__(self, model_config=None, dtype=tf.float32, **kwargs):
    method get_config (line 485) | def get_config(self):
    method from_config (line 494) | def from_config(cls, config, custom_objects=None):
  function get_custom_objects (line 504) | def get_custom_objects(custom_objects=None):

FILE: tensorflow_lattice/python/premade_lib.py
  function _input_calibration_regularizers (line 83) | def _input_calibration_regularizers(model_config, feature_config):
  function _middle_calibration_regularizers (line 93) | def _middle_calibration_regularizers(model_config):
  function _output_calibration_regularizers (line 102) | def _output_calibration_regularizers(model_config):
  function _lattice_regularizers (line 109) | def _lattice_regularizers(model_config, feature_configs):
  class LayerOutputRange (line 138) | class LayerOutputRange(enum.Enum):
  function _output_range (line 145) | def _output_range(layer_output_range, model_config, feature_config=None):
  function build_input_layer (line 176) | def build_input_layer(feature_configs, dtype, ragged=False):
  function build_multi_unit_calibration_layers (line 202) | def build_multi_unit_calibration_layers(calibration_input_layer,
  function build_calibration_layers (line 291) | def build_calibration_layers(calibration_input_layer, model_config,
  function build_aggregation_layer (line 360) | def build_aggregation_layer(aggregation_input_layer, model_config,
  function _monotonicities_from_feature_configs (line 441) | def _monotonicities_from_feature_configs(feature_configs):
  function _dominance_constraints_from_feature_configs (line 455) | def _dominance_constraints_from_feature_configs(feature_configs):
  function _canonical_feature_names (line 471) | def _canonical_feature_names(model_config, feature_names=None):
  function build_linear_layer (line 482) | def build_linear_layer(linear_input, feature_configs, model_config,
  function build_lattice_layer (line 535) | def build_lattice_layer(lattice_input, feature_configs, model_config,
  function build_lattice_ensemble_layer (line 654) | def build_lattice_ensemble_layer(submodels_inputs, model_config, dtype):
  function build_rtl_layer (line 689) | def build_rtl_layer(calibration_outputs, model_config, submodel_index,
  function build_calibrated_lattice_ensemble_layer (line 762) | def build_calibrated_lattice_ensemble_layer(calibration_input_layer,
  function build_linear_combination_layer (line 826) | def build_linear_combination_layer(ensemble_outputs, model_config, dtype):
  function build_output_calibration_layer (line 870) | def build_output_calibration_layer(output_calibration_input, model_config,
  function set_categorical_monotonicities (line 903) | def set_categorical_monotonicities(feature_configs):
  function set_random_lattice_ensemble (line 945) | def set_random_lattice_ensemble(model_config, feature_names=None):
  function _add_pair_to_ensemble (line 984) | def _add_pair_to_ensemble(lattices, lattice_rank, i, j):
  function _set_all_pairs_cover_lattices (line 1012) | def _set_all_pairs_cover_lattices(prefitting_model_config, feature_names):
  function construct_prefitting_model_config (line 1029) | def construct_prefitting_model_config(model_config, feature_names=None):
  function _verify_prefitting_model (line 1078) | def _verify_prefitting_model(prefitting_model, feature_names):
  function _get_lattice_weights (line 1112) | def _get_lattice_weights(prefitting_model, lattice_index):
  function _get_torsions_and_laplacians (line 1127) | def _get_torsions_and_laplacians(prefitting_model_config, prefitting_model,
  function _get_final_crystal_lattices (line 1171) | def _get_final_crystal_lattices(model_config, prefitting_model_config,
  function set_crystals_lattice_ensemble (line 1304) | def set_crystals_lattice_ensemble(model_config,
  function _weighted_quantile (line 1357) | def _weighted_quantile(sorted_values, quantiles, weights):
  function compute_keypoints (line 1392) | def compute_keypoints(values,
  function _feature_config_by_name (line 1489) | def _feature_config_by_name(feature_configs, feature_name, add_if_missing):
  function compute_feature_keypoints (line 1501) | def compute_feature_keypoints(feature_configs,
  function set_feature_keypoints (line 1536) | def set_feature_keypoints(feature_configs, feature_keypoints,
  function compute_label_keypoints (line 1547) | def compute_label_keypoints(model_config,
  function set_label_keypoints (line 1579) | def set_label_keypoints(model_config, label_keypoints):
  function _verify_ensemble_config (line 1584) | def _verify_ensemble_config(model_config):
  function _verify_kronecker_factored_config (line 1666) | def _verify_kronecker_factored_config(model_config):
  function _verify_aggregate_function_config (line 1717) | def _verify_aggregate_function_config(model_config):
  function _verify_feature_config (line 1739) | def _verify_feature_config(feature_config):
  function verify_config (line 1788) | def verify_config(model_config):
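
`compute_keypoints` places calibrator input keypoints from training data; one of its strategies is quantile-based. A rough NumPy sketch of that strategy (the library version also supports weighted samples via `_weighted_quantile` and deduplicates repeated values, which this ignores):

```python
import numpy as np

def quantile_keypoints(values, num_keypoints):
    """Place input keypoints at evenly spaced quantiles of the data.

    Hedged sketch of the quantiles strategy in premade_lib.compute_keypoints;
    not the library's actual implementation.
    """
    quantiles = np.linspace(0.0, 1.0, num_keypoints)
    return np.quantile(np.asarray(values, dtype=float), quantiles)

data = [0.0, 1.0, 1.0, 2.0, 10.0]
print(quantile_keypoints(data, 3))  # min, median, max of the sample
```

Quantile spacing concentrates keypoints where the data is dense, which usually gives the piecewise-linear calibrator more resolution where it matters than uniform spacing would.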

FILE: tensorflow_lattice/python/premade_test.py
  class PremadeTest (line 105) | class PremadeTest(parameterized.TestCase, tf.test.TestCase):
    method setUp (line 108) | def setUp(self):
    method _ResetAllBackends (line 266) | def _ResetAllBackends(self):
    class Encoder (line 270) | class Encoder(json.JSONEncoder):
      method default (line 272) | def default(self, o):
    method testSetRandomLattices (line 279) | def testSetRandomLattices(self):
    method testSetCategoricalMonotonicities (line 306) | def testSetCategoricalMonotonicities(self):
    method testVerifyConfig (line 312) | def testVerifyConfig(self):
    method testLatticeEnsembleFromConfig (line 346) | def testLatticeEnsembleFromConfig(self):
    method testLatticeFromConfig (line 370) | def testLatticeFromConfig(self):
    method testLatticeSimplexFromConfig (line 389) | def testLatticeSimplexFromConfig(self):
    method testLinearFromConfig (line 409) | def testLinearFromConfig(self):
    method testAggregateFromConfig (line 429) | def testAggregateFromConfig(self):
    method testCalibratedLatticeEnsembleCrystals (line 456) | def testCalibratedLatticeEnsembleCrystals(self, interpolation,
    method testCalibratedLatticeEnsembleRTL (line 529) | def testCalibratedLatticeEnsembleRTL(self, interpolation, parameteriza...
    method testCalibratedLattice (line 584) | def testCalibratedLattice(self, interpolation, parameterization, num_t...
    method testLearnedCalibrationInputKeypoints (line 628) | def testLearnedCalibrationInputKeypoints(self):
    method testLatticeEnsembleH5FormatSaveLoad (line 707) | def testLatticeEnsembleH5FormatSaveLoad(self, parameterization, num_te...
    method testLatticeEnsembleRTLH5FormatSaveLoad (line 752) | def testLatticeEnsembleRTLH5FormatSaveLoad(self, parameterization, num...
    method testLatticeH5FormatSaveLoad (line 797) | def testLatticeH5FormatSaveLoad(self, parameterization, num_terms):
    method testLinearH5FormatSaveLoad (line 833) | def testLinearH5FormatSaveLoad(self):
    method testAggregateH5FormatSaveLoad (line 860) | def testAggregateH5FormatSaveLoad(self):

FILE: tensorflow_lattice/python/pwl_calibration_layer.py
  class PWLCalibration (line 49) | class PWLCalibration(keras.layers.Layer):
    method __init__ (line 102) | def __init__(self,
    method build (line 300) | def build(self, input_shape):
    method call (line 396) | def call(self, inputs):
    method compute_output_shape (line 503) | def compute_output_shape(self, input_shape):
    method get_config (line 511) | def get_config(self):
    method assert_constraints (line 538) | def assert_constraints(self, eps=1e-6):
    method keypoints_outputs (line 582) | def keypoints_outputs(self):
    method keypoints_inputs (line 589) | def keypoints_inputs(self):
  class UniformOutputInitializer (line 614) | class UniformOutputInitializer(keras.initializers.Initializer):
    method __init__ (line 627) | def __init__(self, output_min, output_max, monotonicity, keypoints=None):
    method __call__ (line 655) | def __call__(self, shape, dtype=None, partition_info=None):
    method get_config (line 678) | def get_config(self):
  class PWLCalibrationConstraints (line 688) | class PWLCalibrationConstraints(keras.constraints.Constraint):
    method __init__ (line 699) | def __init__(
    method __call__ (line 752) | def __call__(self, w):
    method get_config (line 765) | def get_config(self):
  class NaiveBoundsConstraints (line 779) | class NaiveBoundsConstraints(keras.constraints.Constraint):
    method __init__ (line 789) | def __init__(self, lower_bound=None, upper_bound=None):
    method __call__ (line 799) | def __call__(self, w):
    method get_config (line 807) | def get_config(self):
  class LaplacianRegularizer (line 815) | class LaplacianRegularizer(keras.regularizers.Regularizer):
    method __init__ (line 832) | def __init__(self, l1=0.0, l2=0.0, is_cyclic=False):
    method __call__ (line 845) | def __call__(self, x):
    method get_config (line 873) | def get_config(self):
  class HessianRegularizer (line 882) | class HessianRegularizer(keras.regularizers.Regularizer):
    method __init__ (line 904) | def __init__(self, l1=0.0, l2=0.0, is_cyclic=False):
    method __call__ (line 917) | def __call__(self, x):
    method get_config (line 953) | def get_config(self):
  class WrinkleRegularizer (line 962) | class WrinkleRegularizer(keras.regularizers.Regularizer):
    method __init__ (line 984) | def __init__(self, l1=0.0, l2=0.0, is_cyclic=False):
    method __call__ (line 997) | def __call__(self, x):
    method get_config (line 1037) | def get_config(self):

FILE: tensorflow_lattice/python/pwl_calibration_lib.py
  class BoundConstraintsType (line 27) | class BoundConstraintsType(enum.Enum):
  function convert_all_constraints (line 39) | def convert_all_constraints(output_min, output_max, clamp_min, clamp_max):
  function _convert_constraints (line 73) | def _convert_constraints(value, clamp_to_value):
  function compute_interpolation_weights (line 95) | def compute_interpolation_weights(inputs, keypoints, lengths):
  function linear_initializer (line 129) | def linear_initializer(shape,
  function _approximately_project_bounds_only (line 196) | def _approximately_project_bounds_only(bias, heights, output_min, output...
  function _project_bounds_considering_monotonicity (line 244) | def _project_bounds_considering_monotonicity(bias, heights, monotonicity,
  function _project_convexity (line 360) | def _project_convexity(heights, lengths, convexity, constraint_group):
  function _project_monotonicity (line 462) | def _project_monotonicity(heights, monotonicity):
  function project_all_constraints (line 472) | def project_all_constraints(weights,
  function _squeeze_by_scaling (line 649) | def _squeeze_by_scaling(bias, heights, monotonicity, output_min, output_...
  function _approximately_project_convexity (line 700) | def _approximately_project_convexity(heights, lengths, convexity):
  function _finalize_constraints (line 729) | def _finalize_constraints(bias, heights, monotonicity, output_min, outpu...
  function assert_constraints (line 805) | def assert_constraints(outputs,
  function verify_hyperparameters (line 888) | def verify_hyperparameters(input_keypoints=None,
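
The function a PWL calibrator computes is ordinary piecewise-linear interpolation between input and output keypoints, with inputs clamped to the keypoint range. The layer parameterizes it differently (a bias plus per-segment heights, so the projections above can enforce monotonicity, convexity, and bounds), but the forward pass is conceptually just:

```python
import numpy as np

def pwl_calibrate(x, input_keypoints, output_keypoints):
    """Piecewise-linear calibration: interpolate between keypoints.

    Conceptual sketch of the PWLCalibration forward pass; np.interp also
    clamps x outside the keypoint range to the boundary outputs, matching
    the layer's behaviour for out-of-range inputs.
    """
    return np.interp(x, input_keypoints, output_keypoints)

kp_in = np.array([0.0, 1.0, 2.0])
kp_out = np.array([0.0, 0.5, 2.0])  # hypothetical learned output keypoints
print(pwl_calibrate(1.5, kp_in, kp_out))  # midpoint of segment (1, 0.5)-(2, 2) -> 1.25
```

In this view, the monotonicity constraint in `_project_monotonicity` amounts to requiring `output_keypoints` to be non-decreasing, and the bound constraints clamp them into `[output_min, output_max]`.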

FILE: tensorflow_lattice/python/pwl_calibration_test.py
  class CalibrateWithSeparateMissing (line 46) | class CalibrateWithSeparateMissing(keras.layers.Layer):
    method __init__ (line 53) | def __init__(self, calibration_layer, missing_input_value):
    method call (line 58) | def call(self, x):
  class PwlCalibrationLayerTest (line 64) | class PwlCalibrationLayerTest(parameterized.TestCase, tf.test.TestCase):
    method setUp (line 66) | def setUp(self):
    method _ResetAllBackends (line 73) | def _ResetAllBackends(self):
    method _ScatterXUniformly (line 77) | def _ScatterXUniformly(self, units, num_points, input_min, input_max,
    method _ScatterXUniformlyIncludeBounds (line 94) | def _ScatterXUniformlyIncludeBounds(self, units, **kwargs):
    method _SmallWaves (line 101) | def _SmallWaves(self, x):
    method _SmallWavesPlusOne (line 105) | def _SmallWavesPlusOne(self, x):
    method _WavyParabola (line 108) | def _WavyParabola(self, x):
    method _SinCycle (line 112) | def _SinCycle(self, x):
    method _GenPWLFunction (line 116) | def _GenPWLFunction(self, input_keypoints, pwl_weights):
    method _SetDefaults (line 138) | def _SetDefaults(self, config):
    method _TrainModel (line 174) | def _TrainModel(self, config):
    method _InverseAndTrain (line 271) | def _InverseAndTrain(self, config):
    method _CreateTrainingData (line 293) | def _CreateTrainingData(self, config):
    method _CreateKerasLayer (line 306) | def _CreateKerasLayer(self, config):
    method testUnconstrainedNoMissingValue (line 339) | def testUnconstrainedNoMissingValue(self, units, one_d_input, expected...
    method testUnconstrainedWithMissingValue (line 373) | def testUnconstrainedWithMissingValue(self, units, missing_output_value,
    method testNonMonotonicFunction (line 419) | def testNonMonotonicFunction(self, units, output_min, output_max, opti...
    method testBoundsForMissing (line 458) | def testBoundsForMissing(self, units, missing_input_value, expected_lo...
    method testAllBoundsWithoutMonotonicityConstraints (line 522) | def testAllBoundsWithoutMonotonicityConstraints(self, units, output_min,
    method testMonotonicProperBounds (line 562) | def testMonotonicProperBounds(self, units, is_clamped, optimizer,
    method testMonotonicNarrowBounds (line 601) | def testMonotonicNarrowBounds(self, units, is_clamped, optimizer,
    method testMonotonicWideBounds (line 640) | def testMonotonicWideBounds(self, units, is_clamped, optimizer,
    method testAllBoundsAndMonotonicityDirection (line 798) | def testAllBoundsAndMonotonicityDirection(self, units, output_min, out...
    method testConvexitySimple (line 837) | def testConvexitySimple(self, units, convexity, expected_loss):
    method testConvexityNonUniformKeypoints (line 871) | def testConvexityNonUniformKeypoints(self, units, convexity, expected_...
    method testConvexityDifferentNumKeypoints (line 911) | def testConvexityDifferentNumKeypoints(self, units, num_keypoints,
    method testConvexityWithMonotonicityAndBounds (line 952) | def testConvexityWithMonotonicityAndBounds(self, units, monotonicity,
    method testInputKeypoints (line 989) | def testInputKeypoints(self, keypoints):
    method testIsCyclic (line 1017) | def testIsCyclic(self, units, regularizer, num_training_epoch, expecte...
    method testInitializer (line 1050) | def testInitializer(self, units, initializer, expected_loss):
    method testRegularizers (line 1094) | def testRegularizers(self, units, regularizer, pure_reg_loss, training...
    method testAssertMonotonicity (line 1130) | def testAssertMonotonicity(self):
    method testOutputShape (line 1172) | def testOutputShape(self):
    method testKeypointsInputs (line 1199) | def testKeypointsInputs(self, input_keypoints_type, input_dims, output...

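The pwl_calibration files indexed above implement piecewise-linear calibrators that map an input through learned keypoints. As a rough stand-alone illustration of the underlying interpolation (not the library's actual API; the function name and clamping behavior here are assumptions), a fixed-keypoint PWL calibrator can be sketched as:

```python
import bisect

def pwl_interpolate(x, input_keypoints, output_keypoints):
    """Piecewise-linear interpolation through (input, output) keypoint pairs.

    Values outside the keypoint range are clamped to the boundary outputs,
    mirroring how a calibrator typically behaves out of range.
    """
    if x <= input_keypoints[0]:
        return output_keypoints[0]
    if x >= input_keypoints[-1]:
        return output_keypoints[-1]
    # Find the segment [k_i, k_{i+1}] that contains x.
    i = bisect.bisect_right(input_keypoints, x) - 1
    left, right = input_keypoints[i], input_keypoints[i + 1]
    t = (x - left) / (right - left)
    return (1 - t) * output_keypoints[i] + t * output_keypoints[i + 1]

# Keypoints mapping inputs [0, 1, 2] to outputs [0.0, 0.5, 2.0].
print(pwl_interpolate(1.5, [0.0, 1.0, 2.0], [0.0, 0.5, 2.0]))  # -> 1.25
```

In the real layer the output keypoints are trainable weights and constraints (monotonicity, convexity, bounds) are projected onto them, but the forward pass is this same segment-wise linear blend.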
FILE: tensorflow_lattice/python/rtl_layer.py
  class RTL (line 56) | class RTL(keras.layers.Layer):
    method __init__ (line 130) | def __init__(self,
    method build (line 295) | def build(self, input_shape):
    method call (line 377) | def call(self, x, **kwargs):
    method compute_output_shape (line 430) | def compute_output_shape(self, input_shape):
    method get_config (line 452) | def get_config(self):
    method finalize_constraints (line 478) | def finalize_constraints(self):
    method assert_constraints (line 492) | def assert_constraints(self, eps=1e-6):
    method _get_rtl_structure (line 509) | def _get_rtl_structure(self, input_shape):

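The RTL layer's `_get_rtl_structure` method suggests it randomly assigns input features to an ensemble of small lattices. A minimal sketch of that idea, under stated assumptions (the tiling-and-shuffle scheme and function name are illustrative, not the library's implementation):

```python
import random

def rtl_structure(num_features, num_lattices, lattice_rank, seed=42):
    """Randomly assign feature indices to an ensemble of small lattices.

    Each lattice gets `lattice_rank` inputs; features are tiled so every
    feature is used roughly equally, then shuffled into lattices. A real
    implementation would also resample to avoid repeating a feature
    within a single lattice.
    """
    rng = random.Random(seed)
    slots = num_lattices * lattice_rank
    # Tile the feature list to fill all slots, then shuffle.
    pool = [i % num_features for i in range(slots)]
    rng.shuffle(pool)
    return [sorted(pool[i:i + lattice_rank])
            for i in range(0, slots, lattice_rank)]

structure = rtl_structure(num_features=5, num_lattices=4, lattice_rank=2)
print(structure)  # 4 lattices, each listing 2 feature indices
```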
FILE: tensorflow_lattice/python/rtl_lib.py
  function verify_hyperparameters (line 23) | def verify_hyperparameters(lattice_size,

FILE: tensorflow_lattice/python/rtl_test.py
  class RTLTest (line 35) | class RTLTest(parameterized.TestCase, tf.test.TestCase):
    method setUp (line 37) | def setUp(self):
    method testRTLInputShapes (line 42) | def testRTLInputShapes(self):
    method testRTLOutputShape (line 142) | def testRTLOutputShape(self):
    method testRTLSaveLoad (line 161) | def testRTLSaveLoad(self):

FILE: tensorflow_lattice/python/test_utils.py
  class TimeTracker (line 27) | class TimeTracker(object):
    method __init__ (line 40) | def __init__(self, list_to_append, num_steps=1):
    method __enter__ (line 44) | def __enter__(self):
    method __exit__ (line 48) | def __exit__(self, unuesd_type, unuesd_value, unuesd_traceback):
  function run_training_loop (line 54) | def run_training_loop(config,
  function two_dim_mesh_grid (line 113) | def two_dim_mesh_grid(num_points, x_min, y_min, x_max, y_max):
  function sample_uniformly (line 158) | def sample_uniformly(num_points, lower_bounds, upper_bounds):
  function get_hypercube_interpolation_fn (line 189) | def get_hypercube_interpolation_fn(coefficients):
  function get_linear_lattice_interpolation_fn (line 224) | def get_linear_lattice_interpolation_fn(lattice_sizes, monotonicities,

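Among the test utilities, `get_hypercube_interpolation_fn` points at multilinear interpolation over the vertices of a unit hypercube, which is the core lattice operation. A self-contained sketch (the dict-of-vertex-tuples coefficient layout here is an assumption, not necessarily the library's ordering):

```python
import itertools

def hypercube_interpolation(x, coefficients):
    """Multilinear interpolation inside the unit hypercube [0, 1]^d.

    `coefficients` maps each vertex (a tuple of 0s and 1s) to its value;
    the result is the weighted sum of vertex values, where each weight is
    the product over dimensions of x_i (if the vertex bit is 1) or
    1 - x_i (if it is 0).
    """
    d = len(x)
    total = 0.0
    for vertex in itertools.product((0, 1), repeat=d):
        weight = 1.0
        for xi, bit in zip(x, vertex):
            weight *= xi if bit else 1.0 - xi
        total += weight * coefficients[vertex]
    return total

# 2-D example: interpolate between the corner values of the unit square.
coeffs = {(0, 0): 0.0, (0, 1): 1.0, (1, 0): 2.0, (1, 1): 5.0}
print(hypercube_interpolation((0.5, 0.5), coeffs))  # -> 2.0
```

At the center of the square every vertex weight is 0.25, so the result is the mean of the four corner values; at a corner the function reproduces that corner's coefficient exactly.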
FILE: tensorflow_lattice/python/utils.py
  function canonicalize_convexity (line 25) | def canonicalize_convexity(convexity):
  function canonicalize_input_bounds (line 55) | def canonicalize_input_bounds(input_bounds):
  function canonicalize_monotonicity (line 85) | def canonicalize_monotonicity(monotonicity, allow_decreasing=True):
  function canonicalize_monotonicities (line 128) | def canonicalize_monotonicities(monotonicities, allow_decreasing=True):
  function canonicalize_trust (line 157) | def canonicalize_trust(trusts):
  function canonicalize_unimodalities (line 196) | def canonicalize_unimodalities(unimodalities):
  function count_non_zeros (line 232) | def count_non_zeros(*iterables):

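The canonicalizers listed for utils.py normalize the several user-facing spellings of each hyperparameter into one internal form. A hedged sketch of what `canonicalize_monotonicity` plausibly does (the accepted spellings follow the TFL docs, but this is an illustration, not a copy of the implementation):

```python
def canonicalize_monotonicity(monotonicity, allow_decreasing=True):
    """Map a monotonicity spec to a canonical integer: -1, 0, or 1.

    Accepts the string spellings 'decreasing'/'none'/'increasing'
    (case-insensitive) or the integers -1/0/1; raises ValueError otherwise.
    """
    canonical = {
        "decreasing": -1, "none": 0, "increasing": 1,
        -1: -1, 0: 0, 1: 1,
    }
    key = monotonicity.lower() if isinstance(monotonicity, str) else monotonicity
    if key not in canonical:
        raise ValueError(f"Invalid monotonicity: {monotonicity!r}")
    value = canonical[key]
    if value == -1 and not allow_decreasing:
        raise ValueError("Decreasing monotonicity is not allowed here.")
    return value

print(canonicalize_monotonicity("increasing"))  # -> 1
```

The `allow_decreasing` flag mirrors the signature shown in the index: some layers support only increasing or unconstrained monotonicity, so decreasing specs must be rejected at canonicalization time.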
FILE: tensorflow_lattice/python/utils_test.py
  class UtilsTest (line 25) | class UtilsTest(parameterized.TestCase, tf.test.TestCase):
    method testCanonicalizeConvexity (line 29) | def testCanonicalizeConvexity(self, convexity,
    method testInvalidConvexity (line 36) | def testInvalidConvexity(self, invalid_convexity):
    method testCanonicalizeInputBounds (line 56) | def testCanonicalizeInputBounds(self, input_bounds,
    method testInvalidInputBounds (line 64) | def testInvalidInputBounds(self, invalid_input_bounds):
    method testCanonicalizeMonotonicity (line 73) | def testCanonicalizeMonotonicity(self, monotonicity,
    method testInvalidMonotonicity (line 81) | def testInvalidMonotonicity(self, invalid_monotonicity):
    method testInvalidDecreasingMonotonicity (line 89) | def testInvalidDecreasingMonotonicity(self, invalid_monotonicity):
    method testCanonicalizeMonotonicities (line 103) | def testCanonicalizeMonotonicities(self, monotonicities,
    method testCanonicalizeTrust (line 115) | def testCanonicalizeTrust(self, trusts, expected_canonicalized_trusts):
    method testInvalidTrustDirection (line 125) | def testInvalidTrustDirection(self, invalid_trusts):
    method testInvalidTrustLength (line 138) | def testInvalidTrustLength(self, invalid_trusts):
    method testCountNonZeros (line 149) | def testCountNonZeros(self, monotonicities, unimodalities,
    method testCanonicalizeUnimodalities (line 157) | def testCanonicalizeUnimodalities(self, unimodalities,
    method testInvalidUnimoadlities (line 168) | def testInvalidUnimoadlities(self, invalid_unimodalities):
Condensed preview — 64 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitmodules",
    "chars": 111,
    "preview": "[submodule \"tensorflow\"]\n\tpath = tensorflow\n\turl = https://github.com/tensorflow/tensorflow.git\n\tbranch = r1.3\n"
  },
  {
    "path": "AUTHORS",
    "chars": 229,
    "preview": "# This is the official list of TensorFlow Lattice authors for copyright purposes.\n# Names should be added to this file a"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 1628,
    "preview": "<!-- Copyright 2017 The TensorFlow Lattice Authors.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou"
  },
  {
    "path": "LICENSE",
    "chars": 11358,
    "preview": "\n                                 Apache License\n                           Version 2.0, January 2004\n                  "
  },
  {
    "path": "README.md",
    "chars": 2049,
    "preview": "<!-- Copyright 2020 The TensorFlow Lattice Authors.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou"
  },
  {
    "path": "WORKSPACE",
    "chars": 716,
    "preview": "# Copyright 2018 The TensorFlow Lattice Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\"); you"
  },
  {
    "path": "docs/_book.yaml",
    "chars": 1172,
    "preview": "upper_tabs:\n# Tabs left of dropdown menu\n- include: /_upper_tabs_left.yaml\n- include: /api_docs/_upper_tabs_api.yaml\n# D"
  },
  {
    "path": "docs/_index.yaml",
    "chars": 3909,
    "preview": "book_path: /lattice/_book.yaml\nproject_path: /lattice/_project.yaml\ndescription: A library for training constrained and "
  },
  {
    "path": "docs/build_docs.py",
    "chars": 2754,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "docs/install.md",
    "chars": 930,
    "preview": "# Install TensorFlow Lattice\n\nThere are several ways to set up your environment to use TensorFlow Lattice\n(TFL).\n\n*   Th"
  },
  {
    "path": "docs/overview.md",
    "chars": 10548,
    "preview": "# TensorFlow Lattice (TFL)\n\nTensorFlow Lattice is a library that implements flexible, controlled and\ninterpretable latti"
  },
  {
    "path": "docs/tutorials/aggregate_function_models.ipynb",
    "chars": 21250,
    "preview": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"RYmPh1qB_KO2\"\n      },\n      \"sou"
  },
  {
    "path": "docs/tutorials/keras_layers.ipynb",
    "chars": 29715,
    "preview": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"7765UFHoyGx6\"\n      },\n      \"sou"
  },
  {
    "path": "docs/tutorials/premade_models.ipynb",
    "chars": 36518,
    "preview": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"HZiF5lbumA7j\"\n      },\n      \"sou"
  },
  {
    "path": "docs/tutorials/shape_constraints.ipynb",
    "chars": 48463,
    "preview": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"7765UFHoyGx6\"\n      },\n      \"sou"
  },
  {
    "path": "docs/tutorials/shape_constraints_for_ethics.ipynb",
    "chars": 39764,
    "preview": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"R2AxpObRncMd\"\n      },\n      \"sou"
  },
  {
    "path": "examples/BUILD",
    "chars": 1289,
    "preview": "# Copyright 2019 The TensorFlow Lattice Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# y"
  },
  {
    "path": "examples/keras_functional_uci_heart.py",
    "chars": 13269,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "examples/keras_sequential_uci_heart.py",
    "chars": 11592,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "setup.py",
    "chars": 3512,
    "preview": "# Copyright 2018 The TensorFlow Lattice Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\"); you"
  },
  {
    "path": "tensorflow_lattice/BUILD",
    "chars": 2212,
    "preview": "# Copyright 2017 The TensorFlow Lattice Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# y"
  },
  {
    "path": "tensorflow_lattice/__init__.py",
    "chars": 1919,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/layers/__init__.py",
    "chars": 1255,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/BUILD",
    "chars": 11465,
    "preview": "# Copyright 2019 The TensorFlow Lattice Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# y"
  },
  {
    "path": "tensorflow_lattice/python/__init__.py",
    "chars": 727,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/aggregation_layer.py",
    "chars": 2809,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/aggregation_test.py",
    "chars": 1922,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/categorical_calibration_layer.py",
    "chars": 12516,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/categorical_calibration_lib.py",
    "chars": 5723,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/categorical_calibration_test.py",
    "chars": 12352,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/cdf_layer.py",
    "chars": 11722,
    "preview": "# Copyright 2021 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/cdf_test.py",
    "chars": 20117,
    "preview": "# Copyright 2021 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/conditional_cdf.py",
    "chars": 10352,
    "preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/conditional_cdf_test.py",
    "chars": 24708,
    "preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/conditional_pwl_calibration.py",
    "chars": 18220,
    "preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/conditional_pwl_calibration_test.py",
    "chars": 18538,
    "preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/configs.py",
    "chars": 53795,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/configs_test.py",
    "chars": 9435,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/internal_utils.py",
    "chars": 6449,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/internal_utils_test.py",
    "chars": 1863,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/kronecker_factored_lattice_layer.py",
    "chars": 26295,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/kronecker_factored_lattice_lib.py",
    "chars": 30221,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/kronecker_factored_lattice_test.py",
    "chars": 36009,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/lattice_layer.py",
    "chars": 46125,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/lattice_lib.py",
    "chars": 123130,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/lattice_test.py",
    "chars": 74041,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/linear_layer.py",
    "chars": 17519,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/linear_lib.py",
    "chars": 18352,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/linear_test.py",
    "chars": 23036,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/model_info.py",
    "chars": 4035,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/parallel_combination_layer.py",
    "chars": 6659,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/parallel_combination_test.py",
    "chars": 5615,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/premade.py",
    "chars": 22477,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/premade_lib.py",
    "chars": 78433,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/premade_test.py",
    "chars": 34764,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/pwl_calibration_layer.py",
    "chars": 41546,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/pwl_calibration_lib.py",
    "chars": 40662,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/pwl_calibration_test.py",
    "chars": 46874,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/rtl_layer.py",
    "chars": 28473,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/rtl_lib.py",
    "chars": 4858,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/rtl_test.py",
    "chars": 7495,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/test_utils.py",
    "chars": 9046,
    "preview": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/utils.py",
    "chars": 8203,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "tensorflow_lattice/python/utils_test.py",
    "chars": 8940,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  }
]

About this extraction

This page contains the full source code of the tensorflow/lattice GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 64 files (1.2 MB), approximately 303.7k tokens, and a symbol index with 628 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.