[
  {
    "path": ".gitmodules",
    "content": "[submodule \"tensorflow\"]\n\tpath = tensorflow\n\turl = https://github.com/tensorflow/tensorflow.git\n\tbranch = r1.3\n"
  },
  {
    "path": "AUTHORS",
    "content": "# This is the official list of TensorFlow Lattice authors for copyright purposes.\n# Names should be added to this file as:\n# Name or Organization <email address>\n# The email address is not required for organizations.\nGoogle Inc.\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "<!-- Copyright 2017 The TensorFlow Lattice Authors.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n     http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n=============================================================================-->\n# How to Contribute\n\nWe'd love to accept your patches and contributions to this project. There are\njust a few small guidelines you need to follow.\n\n## Contributor License Agreement\n\nContributions to this project must be accompanied by a Contributor License\nAgreement. You (or your employer) retain the copyright to your contribution,\nthis simply gives us permission to use and redistribute your contributions as\npart of the project. Head over to <https://cla.developers.google.com/> to see\nyour current agreements on file or to sign a new one.\n\nYou generally only need to submit a CLA once, so if you've already submitted one\n(even if it was for a different project), you probably don't need to do it\nagain.\n\n## Code reviews\n\nAll submissions, including submissions by project members, require review. We\nuse GitHub pull requests for this purpose. Consult\n[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more\ninformation on using pull requests.\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<!-- Copyright 2020 The TensorFlow Lattice Authors.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n     http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n=============================================================================-->\n# TensorFlow Lattice\n\nTensorFlow Lattice is a library that implements constrained and interpretable\nlattice based models. It is an implementation of\n[Monotonic Calibrated Interpolated Look-Up Tables](http://jmlr.org/papers/v17/15-243.html)\nin [TensorFlow](https://www.tensorflow.org).\n\nThe library enables you to inject domain knowledge into\nthe learning process through common-sense or policy-driven shape constraints.\nThis is done using a collection of Keras layers that can satisfy constraints\nsuch as monotonicity, convexity and pairwise trust:\n\n* PWLCalibration: piecewise linear calibration of signals.\n* CategoricalCalibration: mapping of categorical inputs into real values.\n* Lattice: interpolated look-up table implementation.\n* Linear: linear function with monotonicity and norm constraints.\n\nThe library also provides easy to setup canned estimators for common use cases:\n\n* Calibrated Linear\n* Calibrated Lattice\n* Random Tiny Lattices (RTL)\n* Crystals\n\nWith TF Lattice you can use domain knowledge to better extrapolate to the parts\nof the input space not covered by the training dataset. This helps avoid\nunexpected model behaviour when the serving distribution is different from the\ntraining distribution.\n\n<div align=\"center\">\n  <img src=\"docs/images/model_comparison.png\">\n</div>\n\nYou can install our prebuilt pip package using\n\n```bash\npip install tensorflow-lattice\n```\n"
  },
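  Note: the README lists the core Keras layers without showing them in use. For illustration only (this sketch is not part of the README; the feature ranges, keypoint counts and lattice sizes below are hypothetical), the layers compose roughly like this, mirroring the example on the docs landing page:

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_lattice as tfl

  # One numeric feature and one categorical feature, each calibrated and then
  # combined by an interpolated look-up table (Lattice).
  model = tf.keras.models.Sequential()
  model.add(
      tfl.layers.ParallelCombination([
          # Monotonic piecewise linear calibration of a numeric input in [0, 10].
          tfl.layers.PWLCalibration(
              input_keypoints=np.linspace(0.0, 10.0, num=10),
              monotonicity='increasing',
              output_min=0.0,
              output_max=1.0),
          # Categorical calibration of a 3-bucket input into the lattice range.
          tfl.layers.CategoricalCalibration(
              num_buckets=3,
              output_min=0.0,
              output_max=2.0),
      ]))
  model.add(
      tfl.layers.Lattice(
          lattice_sizes=[2, 3],
          monotonicities=['increasing', 'none']))
  model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(0.01))
  ```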
  {
    "path": "WORKSPACE",
    "content": "# Copyright 2018 The TensorFlow Lattice Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\"); you may not\n# use this file except in compliance with the License. You may obtain a copy of\n# the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT\n# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\n# License for the specific language governing permissions and limitations under\n# the License.\n# ==============================================================================\n\nworkspace(name = \"tensorflow_lattice\")\n"
  },
  {
    "path": "docs/_book.yaml",
    "content": "upper_tabs:\n# Tabs left of dropdown menu\n- include: /_upper_tabs_left.yaml\n- include: /api_docs/_upper_tabs_api.yaml\n# Dropdown menu\n- name: Resources\n  path: /resources\n  is_default: true\n  menu:\n  - include: /resources/_menu_toc.yaml\n  lower_tabs:\n    # Subsite tabs\n    other:\n    - name: Guide & Tutorials\n      contents:\n      - title: Overview\n        path: /lattice/overview\n      - title: Install\n        path: /lattice/install\n      - heading: Tutorials\n      - title: Shape Constraints\n        path: /lattice/tutorials/shape_constraints\n      - title: Ethical Constraints for ML Fairness\n        path: /lattice/tutorials/shape_constraints_for_ethics\n      - title: Keras Layers and Custom Models\n        path: /lattice/tutorials/keras_layers\n      - title: Keras Premade Models\n        path: /lattice/tutorials/premade_models\n      - title: Aggregate Function Models\n        path: /lattice/tutorials/aggregate_function_models\n\n    - name: API\n      skip_translation: true\n      contents:\n      - title: All Symbols\n        path: /lattice/api_docs/python/tfl/all_symbols\n      - include: /lattice/api_docs/python/tfl/_toc.yaml\n\n- include: /_upper_tabs_right.yaml\n"
  },
  {
    "path": "docs/_index.yaml",
    "content": "book_path: /lattice/_book.yaml\nproject_path: /lattice/_project.yaml\ndescription: A library for training constrained and interpretable lattice based models. Inject\n domain knowledge into the learning process through constraints on Keras layers.\nlanding_page:\n  custom_css_path: /site-assets/css/style.css\n  rows:\n  - heading: Flexible, controlled and interpretable ML with lattice based models\n    items:\n    - classname: devsite-landing-row-50\n      description: >\n        <p>TensorFlow Lattice is a library that implements constrained and interpretable lattice\n        based models. The library enables you to inject domain knowledge into the learning process\n        through common-sense or policy-driven\n        <a href=\"./tutorials/shape_constraints\">shape constraints</a>. This is done using a\n        collection of <a href=\"./tutorials/keras_layers\">Keras layers</a> that can satisfy\n        constraints such as monotonicity, convexity and how features interact. The library also\n        provides easy to setup <a href=\"./tutorials/premade_models\">premade models</a>.</p>\n        <p>With TF Lattice you can use domain knowledge to better extrapolate to the parts of the\n        input space not covered by the training dataset. This helps avoid unexpected model behaviour\n        when the serving distribution is different from the training distribution.</p>\n        <figure>\n            <img src=\"images/model_comparison.png\">\n        </figure>\n\n      code_block: |\n        <pre class = \"prettyprint\">\n        import numpy as np\n        import tensorflow as tf\n        import tensorflow_lattice as tfl\n\n        model = tf.keras.models.Sequential()\n        model.add(\n            tfl.layers.ParallelCombination([\n                # Monotonic piece-wise linear calibration with bounded output\n                tfl.layers.PWLCalibration(\n                    monotonicity='increasing',\n                    input_keypoints=np.linspace(1., 5., num=20),\n                    output_min=0.0,\n                    output_max=1.0),\n                # Diminishing returns\n                tfl.layers.PWLCalibration(\n                    monotonicity='increasing',\n                    convexity='concave',\n                    input_keypoints=np.linspace(0., 200., num=20),\n                    output_min=0.0,\n                    output_max=2.0),\n                # Partially monotonic categorical calibration: calib(0) <= calib(1)\n                tfl.layers.CategoricalCalibration(\n                    num_buckets=4,\n                    output_min=0.0,\n                    output_max=1.0,\n                    monotonicities=[(0, 1)]),\n            ]))\n        model.add(\n            tfl.layers.Lattice(\n                lattice_sizes=[2, 3, 2],\n                monotonicities=['increasing', 'increasing', 'increasing'],\n                # Trust: model is more responsive to input 0 if input 1 increases\n                edgeworth_trusts=(0, 1, 'positive')))\n        model.compile(...)\n        </pre>\n\n  - classname: devsite-landing-row-cards\n    items:\n    - heading: \"TensorFlow Lattice: Flexible, controlled and interpretable ML\"\n      image_path: /resources/images/tf-logo-card-16x9.png\n      path: https://blog.tensorflow.org/2020/02/tensorflow-lattice-flexible-controlled-and-interpretable-ML.html\n      buttons:\n      - label: \"Read on the TensorFlow blog\"\n        path: 
https://blog.tensorflow.org/2020/02/tensorflow-lattice-flexible-controlled-and-interpretable-ML.html\n    - heading: \"TensorFlow Lattice: Control your ML with monotonicity\"\n      youtube_id: ABBnNjbjv2Q\n      buttons:\n      - label: Watch the video\n        path: https://www.youtube.com/watch?v=ABBnNjbjv2Q\n    - heading: \"TF Lattice on GitHub\"\n      image_path: /resources/images/github-card-16x9.png\n      path: https://github.com/tensorflow/lattice\n      buttons:\n      - label: \"View on GitHub\"\n        path: https://github.com/tensorflow/lattice\n"
  },
  {
    "path": "docs/build_docs.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Generate docs API for TF Lattice.\n\nExample run:\n\n```\npython build_docs.py --output_dir=/path/to/output\n```\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport sys\n\nfrom absl import app\nfrom absl import flags\n\nfrom tensorflow_docs.api_generator import generate_lib\nfrom tensorflow_docs.api_generator import public_api\n\nimport tensorflow_lattice as tfl\n\nflags.DEFINE_string('output_dir', '/tmp/tfl_api/',\n                    'The path to output the files to')\n\nflags.DEFINE_string(\n    'code_url_prefix',\n    'https://github.com/tensorflow/lattice/blob/master/tensorflow_lattice',\n    'The url prefix for links to code.')\n\nflags.DEFINE_bool('search_hints', True,\n                  'Include metadata search hints in the generated files')\n\nflags.DEFINE_string('site_path', 'lattice/api_docs/python',\n                    'Path prefix in the _toc.yaml')\n\nFLAGS = flags.FLAGS\n\n\ndef local_definitions_filter(path, parent, children):\n  \"\"\"Filters local imports, except for the tfl.layers module.\"\"\"\n  if path == ('tfl', 'layers'):\n    return children\n  return public_api.local_definitions_filter(path, parent, children)\n\n\ndef main(_):\n  private_map = {\n      'tfl': ['python'],\n      'tfl.aggregation_layer': ['Aggregation'],\n      'tfl.categorical_calibration_layer': ['CategoricalCalibration'],\n      'tfl.cdf_layer': ['CDF'],\n      'tfl.kronecker_factored_lattice_layer': ['KroneckerFactoredLattice'],\n      'tfl.lattice_layer': ['Lattice'],\n      'tfl.linear_layer': ['Linear'],\n      'tfl.pwl_calibration_layer': ['PWLCalibration'],\n      'tfl.parallel_combination_layer': ['ParallelCombination'],\n      'tfl.rtl_layer': ['RTL'],\n  }\n  doc_generator = generate_lib.DocGenerator(\n      root_title='TensorFlow Lattice 2.0',\n      py_modules=[('tfl', tfl)],\n      base_dir=os.path.dirname(tfl.__file__),\n      code_url_prefix=FLAGS.code_url_prefix,\n      search_hints=FLAGS.search_hints,\n      site_path=FLAGS.site_path,\n      private_map=private_map,\n      callbacks=[local_definitions_filter])\n\n  sys.exit(doc_generator.build(output_dir=FLAGS.output_dir))\n\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
  {
    "path": "docs/install.md",
    "content": "# Install TensorFlow Lattice\n\nThere are several ways to set up your environment to use TensorFlow Lattice\n(TFL).\n\n*   The easiest way to learn and use TFL requires no installation: run the any\n    of the tutorials (e.g.\n    [premade models](tutorials/premade_models.ipynb)).\n*   To use TFL on a local machine, install the `tensorflow-lattice` pip package.\n*   If you have a unique machine configuration, you can build the package from\n    source.\n\n## Install TensorFlow Lattice using pip\n\nInstall using pip.\n\n```shell\npip install --upgrade tensorflow-lattice\n```\n\nNote that you will need to have `tf_keras` package installed as well.\n\n## Build from source\n\nClone the github repo:\n\n```shell\ngit clone https://github.com/tensorflow/lattice.git\n```\n\nBuild pip package from source:\n\n```shell\npython setup.py sdist bdist_wheel --universal --release\n```\n\nInstall the package:\n\n```shell\npip install --user --upgrade /path/to/pkg.whl\n```\n"
  },
  {
    "path": "docs/overview.md",
    "content": "# TensorFlow Lattice (TFL)\n\nTensorFlow Lattice is a library that implements flexible, controlled and\ninterpretable lattice based models. The library enables you to inject domain\nknowledge into the learning process through common-sense or policy-driven\n[shape constraints](tutorials/shape_constraints.ipynb). This is done using a\ncollection of [Keras layers](tutorials/keras_layers.ipynb) that can satisfy\nconstraints such as monotonicity, convexity and pairwise trust. The library also\nprovides easy to setup [premade models](tutorials/premade_models.ipynb).\n\n## Concepts\n\nThis section is a simplified version of the description in\n[Monotonic Calibrated Interpolated Look-Up Tables](http://jmlr.org/papers/v17/15-243.html)\n, JMLR 2016.\n\n### Lattices\n\nA *lattice* is an interpolated look-up table that can approximate arbitrary\ninput-output relationships in your data. It overlaps a regular grid onto your\ninput space and learns values for the output in the vertices of the grid. For a\ntest point $x$, $f(x)$ is linearly interpolated from the lattice values\nsurrounding $x$.\n\n<img src=\"images/2d_lattice.png\" style=\"display:block; margin:auto;\">\n\nThe simple example above is a function with 2 input features and 4 parameters:\n$\\theta=[0, 0.2, 0.4, 1]$, which are the function's values at the corners of the\ninput space; the rest of the function is interpolated from these parameters.\n\nThe function $f(x)$ can capture non-linear interactions between features. You\ncan think of the lattice parameters as the height of poles set in the ground on\na regular grid, and the resulting function is like cloth pulled tight against\nthe four poles.\n\nWith $D$ features and 2 vertices along each dimension, a regular lattice will\nhave $2^D$ parameters. To fit a more flexible function, you can specify a\nfiner-grained lattice over the feature space with more vertices along each\ndimension. Lattice regression functions are continuous and piecewise infinitely\ndifferentiable.\n\n### Calibration\n\nLet's say the preceding sample lattice represents a learned *user happiness*\nwith a suggested local coffee shop calculated using features:\n\n*   coffee price, in range 0 to 20 dollars\n*   distance to the user, in range 0 to 30 kilometers\n\nWe want our model to learn user happiness with a local coffee shop suggestion.\nTensorFlow Lattice models can use *piecewise linear functions* (with\n`tfl.layers.PWLCalibration`) to calibrate and normalize the input features to\nthe range accepted by the lattice: 0.0 to 1.0 in the example lattice above. The\nfollowing show examples such calibrations functions with 10 keypoints:\n\n<p align=\"center\">\n<img src=\"images/pwl_calibration_distance.png\">\n<img src=\"images/pwl_calibration_price.png\">\n</p>\n\nIt is often a good idea to use the quantiles of the features as input keypoints.\nTensorFlow Lattice [premade models](tutorials/premade_models.ipynb) can\nautomatically set the input keypoints to the feature quantiles.\n\nFor categorical features, TensorFlow Lattice provides categorical calibration\n(with `tfl.layers.CategoricalCalibration`) with similar output bounding to feed\ninto a lattice.\n\n### Ensembles\n\nThe number of parameters of a lattice layer increases exponentially with the\nnumber of input features, hence not scaling well to very high dimensions. 
To\novercome this limitation, TensorFlow Lattice offers ensembles of lattices that\ncombine (average) several *tiny* lattices, which enables the model to grow\nlinearly in the number of features.\n\nThe library provides two variations of these ensembles:\n\n*   **Random Tiny Lattices** (RTL): Each submodel uses a random subset of\n    features (with replacement).\n\n*   **Crystals** : The Crystals algorithm first trains a *prefitting* model that\n    estimates pairwise feature interactions. It then arranges the final ensemble\n    such that features with more non-linear interactions are in the same\n    lattices.\n\n## Why TensorFlow Lattice ?\n\nYou can find a brief introduction to TensorFlow Lattice in this\n[TF Blog post](https://blog.tensorflow.org/2020/02/tensorflow-lattice-flexible-controlled-and-interpretable-ML.html).\n\n### Interpretability\n\nSince the parameters of each layer are the output of that layer, it is easy to\nanalyze, understand and debug each part of the model.\n\n### Accurate and Flexible Models\n\nUsing fine-grained lattices, you can get *arbitrarily complex* functions with a\nsingle lattice layer. Using multiple layers of calibrators and lattices often\nwork nicely in practice and can match or outperform DNN models of similar sizes.\n\n### Common-Sense Shape Constraints\n\nReal world training data may not sufficiently represent the run-time data.\nFlexible ML solutions such as DNNs or forests often act unexpectedly and even\nwildly in parts of the input space not covered by the training data. This\nbehaviour is especially problematic when policy or fairness constraints can be\nviolated.\n\n<img src=\"images/model_comparison.png\" style=\"display:block; margin:auto;\">\n\nEven though common forms of regularization can result in more sensible\nextrapolation, standard regularizers cannot guarantee reasonable model behaviour\nacross the entire input space, especially with high-dimensional inputs.\nSwitching to simpler models with more controlled and predictable behaviour can\ncome at a severe cost to the model accuracy.\n\nTF Lattice makes it possible to keep using flexible models, but provides several\noptions to inject domain knowledge into the learning process through\nsemantically meaningful common-sense or policy-driven\n[shape constraints](tutorials/shape_constraints.ipynb):\n\n*   **Monotonicity**: You can specify that the output should only\n    increase/decrease with respect to an input. In our example, you may want to\n    specify that increased distance to a coffee shop should only decrease the\n    predicted user preference.\n\n<p align=\"center\">\n<img src=\"images/linear_fit.png\">\n<img src=\"images/flexible_fit.png\">\n<img src=\"images/regularized_fit.png\">\n<img src=\"images/monotonic_fit.png\">\n</p>\n\n*   **Convexity/Concavity**: You can specify that the function shape can be\n    convex or concave. Mixed with monotonicity, this can force the function to\n    represent diminishing returns with respect to a given feature.\n\n*   **Unimodality**: You can specify that the function should have a unique peak\n    or unique valley. This lets you represent functions that have a *sweet spot*\n    with respect to a feature.\n\n*   **Pairwise trust**: This constraint works on a pair of features and suggests\n    that one input feature semantically reflects trust in another feature. For\n    example, higher number of reviews makes you more confident in the average\n    star rating of a restaurant. 
The model will be more sensitive with respect\n    to the star rating (i.e. will have a larger slope with respect to the\n    rating) when the number of reviews is higher.\n\n### Controlled Flexibility with Regularizers\n\nIn addition to shape constraints, TensorFlow lattice provides a number of\nregularizers to control the flexibility and smoothness of the function for each\nlayer.\n\n*   **Laplacian Regularizer**: Outputs of the lattice/calibration\n    vertices/keypoints are regularized towards the values of their respective\n    neighbors. This results in a *flatter* function.\n\n*   **Hessian Regularizer**: This penalizes the first derivative of the PWL\n    calibration layer to make the function *more linear*.\n\n*   **Wrinkle Regularizer**: This penalizes the second derivative of the PWL\n    calibration layer to avoid sudden changes in the curvature. It makes the\n    function smoother.\n\n*   **Torsion Regularizer**: Outputs of the lattice will be regularized towards\n    preventing torsion among the features. In other words, the model will be\n    regularized towards independence between the contributions of the features.\n\n### Mix and match with other Keras layers\n\nYou can use TF Lattice layers in combination with other Keras layers to\nconstruct partially constrained or regularized models. For example, lattice or\nPWL calibration layers can be used at the last layer of deeper networks that\ninclude embeddings or other Keras layers.\n\n## Papers\n\n*   [Deontological Ethics By Monotonicity Shape Constraints](https://arxiv.org/abs/2001.11990),\n    Serena Wang, Maya Gupta, International Conference on Artificial Intelligence\n    and Statistics (AISTATS), 2020\n*   [Shape Constraints for Set Functions](http://proceedings.mlr.press/v97/cotter19a.html),\n    Andrew Cotter, Maya Gupta, H. Jiang, Erez Louidor, Jim Muller, Taman\n    Narayan, Serena Wang, Tao Zhu. International Conference on Machine Learning\n    (ICML), 2019\n*   [Diminishing Returns Shape Constraints for Interpretability and\n    Regularization](https://papers.nips.cc/paper/7916-diminishing-returns-shape-constraints-for-interpretability-and-regularization),\n    Maya Gupta, Dara Bahri, Andrew Cotter, Kevin Canini, Advances in Neural\n    Information Processing Systems (NeurIPS), 2018\n*   [Deep Lattice Networks and Partial Monotonic Functions](https://research.google.com/pubs/pub46327.html),\n    Seungil You, Kevin Canini, David Ding, Jan Pfeifer, Maya R. Gupta, Advances\n    in Neural Information Processing Systems (NeurIPS), 2017\n*   [Fast and Flexible Monotonic Functions with Ensembles of Lattices](https://papers.nips.cc/paper/6377-fast-and-flexible-monotonic-functions-with-ensembles-of-lattices),\n    Mahdi Milani Fard, Kevin Canini, Andrew Cotter, Jan Pfeifer, Maya Gupta,\n    Advances in Neural Information Processing Systems (NeurIPS), 2016\n*   [Monotonic Calibrated Interpolated Look-Up Tables](http://jmlr.org/papers/v17/15-243.html),\n    Maya Gupta, Andrew Cotter, Jan Pfeifer, Konstantin Voevodski, Kevin Canini,\n    Alexander Mangylov, Wojciech Moczydlowski, Alexander van Esbroeck, Journal\n    of Machine Learning Research (JMLR), 2016\n*   [Optimized Regression for Efficient Function Evaluation](http://ieeexplore.ieee.org/document/6203580/),\n    Eric Garcia, Raman Arora, Maya R. 
Gupta, IEEE Transactions on Image\n    Processing, 2012\n*   [Lattice Regression](https://papers.nips.cc/paper/3694-lattice-regression),\n    Eric Garcia, Maya Gupta, Advances in Neural Information Processing Systems\n    (NeurIPS), 2009\n\n## Tutorials and API docs\n\nFor common model architectures, you can use\n[Keras premade models](tutorials/premade_models.ipynb). You can also create\ncustom models using [TF Lattice Keras layers](tutorials/keras_layers.ipynb) or\nmix and match with other Keras layers. Check out the\n[full API docs](https://www.tensorflow.org/lattice/api_docs/python/tfl) for\ndetails.\n"
  },
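  Note: as a rough sketch of how the overview's coffee-shop example could be expressed with calibration and lattice layers (the keypoint counts, lattice sizes and learning rate here are made-up illustration values, not taken from the docs):

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_lattice as tfl

  model = tf.keras.models.Sequential()
  model.add(
      tfl.layers.ParallelCombination([
          # Price in [0, 20] dollars: a higher price should only lower predicted
          # happiness, so the calibrator is monotonically decreasing.
          tfl.layers.PWLCalibration(
              input_keypoints=np.linspace(0.0, 20.0, num=10),
              monotonicity='decreasing',
              output_min=0.0,
              output_max=1.0),
          # Distance in [0, 30] km: farther shops should only lower happiness.
          tfl.layers.PWLCalibration(
              input_keypoints=np.linspace(0.0, 30.0, num=10),
              monotonicity='decreasing',
              output_min=0.0,
              output_max=1.0),
      ]))
  model.add(
      tfl.layers.Lattice(
          lattice_sizes=[2, 2],
          # The lattice is kept increasing in both calibrated inputs; the
          # decreasing directions are handled by the calibrators above.
          monotonicities=['increasing', 'increasing'],
          output_min=0.0,
          output_max=1.0))
  model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(0.01))
  ```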
  {
    "path": "docs/tutorials/aggregate_function_models.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"RYmPh1qB_KO2\"\n      },\n      \"source\": [\n        \"##### Copyright 2020 The TensorFlow Authors.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"oMRm3czy9tLh\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@title Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n\",\n        \"# you may not use this file except in compliance with the License.\\n\",\n        \"# You may obtain a copy of the License at\\n\",\n        \"#\\n\",\n        \"# https://www.apache.org/licenses/LICENSE-2.0\\n\",\n        \"#\\n\",\n        \"# Unless required by applicable law or agreed to in writing, software\\n\",\n        \"# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n\",\n        \"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n\",\n        \"# See the License for the specific language governing permissions and\\n\",\n        \"# limitations under the License.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ooXoR4kx_YL9\"\n      },\n      \"source\": [\n        \"# TF Lattice Aggregate Function Models\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"BR6XNYEXEgSU\"\n      },\n      \"source\": [\n        \"\\u003ctable class=\\\"tfo-notebook-buttons\\\" align=\\\"left\\\"\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://www.tensorflow.org/lattice/tutorials/aggregate_function_models\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/tf_logo_32px.png\\\" /\\u003eView on TensorFlow.org\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/aggregate_function_models.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/colab_logo_32px.png\\\" /\\u003eRun in Google Colab\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/aggregate_function_models.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\\\" /\\u003eView source on GitHub\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca href=\\\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/aggregate_function_models.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/download_logo_32px.png\\\" /\\u003eDownload notebook\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"\\u003c/table\\u003e\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"-ZfQWUmfEsyZ\"\n      },\n      \"source\": [\n        \"## Overview\\n\",\n        \"\\n\",\n        \"TFL Premade Aggregate Function Models are quick and easy ways to build TFL `keras.Model` instances for learning complex aggregation functions. 
This guide outlines the steps needed to construct a TFL Premade Aggregate Function Model and train/test it.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"L0lgWoB6Gmk1\"\n      },\n      \"source\": [\n        \"## Setup\\n\",\n        \"\\n\",\n        \"Installing TF Lattice package:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"ivwKrEdLGphZ\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@test {\\\"skip\\\": true}\\n\",\n        \"!pip install -U tensorflow tf-keras tensorflow-lattice  pydot graphviz\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"VQsRKS4wGrMu\"\n      },\n      \"source\": [\n        \"Importing required packages:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"j41-kd4MGtDS\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import tensorflow as tf\\n\",\n        \"\\n\",\n        \"import collections\\n\",\n        \"import logging\\n\",\n        \"import numpy as np\\n\",\n        \"import pandas as pd\\n\",\n        \"import sys\\n\",\n        \"import tensorflow_lattice as tfl\\n\",\n        \"logging.disable(sys.maxsize)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"HlJH1SMx3Vul\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Use Keras 2.\\n\",\n        \"version_fn = getattr(tf.keras, \\\"version\\\", None)\\n\",\n        \"if version_fn and version_fn().startswith(\\\"3.\\\"):\\n\",\n        \"  import tf_keras as keras\\n\",\n        \"else:\\n\",\n        \"  keras = tf.keras\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ZHPohKjBIFG5\"\n      },\n      \"source\": [\n        \"Downloading the Puzzles dataset:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"VjYHpw2dSfHH\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"train_dataframe = pd.read_csv(\\n\",\n        \"    'https://raw.githubusercontent.com/wbakst/puzzles_data/master/train.csv')\\n\",\n        \"train_dataframe.head()\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"UOsgu3eIEur6\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"test_dataframe = pd.read_csv(\\n\",\n        \"    'https://raw.githubusercontent.com/wbakst/puzzles_data/master/test.csv')\\n\",\n        \"test_dataframe.head()\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"XG7MPCyzVr22\"\n      },\n      \"source\": [\n        \"Extract and convert features and labels\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"bYdJicq5bBuz\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Features:\\n\",\n        \"# - star_rating       rating out of 5 stars (1-5)\\n\",\n        \"# - word_count        number of words in the review\\n\",\n        \"# - is_amazon         1 = reviewed on amazon; 0 = reviewed on artifact website\\n\",\n        \"# - includes_photo    if the review includes a photo of the 
puzzle\\n\",\n        \"# - num_helpful       number of people that found this review helpful\\n\",\n        \"# - num_reviews       total number of reviews for this puzzle (we construct)\\n\",\n        \"#\\n\",\n        \"# This ordering of feature names will be the exact same order that we construct\\n\",\n        \"# our model to expect.\\n\",\n        \"feature_names = [\\n\",\n        \"    'star_rating', 'word_count', 'is_amazon', 'includes_photo', 'num_helpful',\\n\",\n        \"    'num_reviews'\\n\",\n        \"]\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"kx0ZX2HR-4qb\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def extract_features(dataframe, label_name):\\n\",\n        \"  # First we extract flattened features.\\n\",\n        \"  flattened_features = {\\n\",\n        \"      feature_name: dataframe[feature_name].values.astype(float)\\n\",\n        \"      for feature_name in feature_names[:-1]\\n\",\n        \"  }\\n\",\n        \"\\n\",\n        \"  # Construct mapping from puzzle name to feature.\\n\",\n        \"  star_rating = collections.defaultdict(list)\\n\",\n        \"  word_count = collections.defaultdict(list)\\n\",\n        \"  is_amazon = collections.defaultdict(list)\\n\",\n        \"  includes_photo = collections.defaultdict(list)\\n\",\n        \"  num_helpful = collections.defaultdict(list)\\n\",\n        \"  labels = {}\\n\",\n        \"\\n\",\n        \"  # Extract each review.\\n\",\n        \"  for i in range(len(dataframe)):\\n\",\n        \"    row = dataframe.iloc[i]\\n\",\n        \"    puzzle_name = row['puzzle_name']\\n\",\n        \"    star_rating[puzzle_name].append(float(row['star_rating']))\\n\",\n        \"    word_count[puzzle_name].append(float(row['word_count']))\\n\",\n        \"    is_amazon[puzzle_name].append(float(row['is_amazon']))\\n\",\n        \"    includes_photo[puzzle_name].append(float(row['includes_photo']))\\n\",\n        \"    num_helpful[puzzle_name].append(float(row['num_helpful']))\\n\",\n        \"    labels[puzzle_name] = float(row[label_name])\\n\",\n        \"\\n\",\n        \"  # Organize data into list of list of features.\\n\",\n        \"  names = list(star_rating.keys())\\n\",\n        \"  star_rating = [star_rating[name] for name in names]\\n\",\n        \"  word_count = [word_count[name] for name in names]\\n\",\n        \"  is_amazon = [is_amazon[name] for name in names]\\n\",\n        \"  includes_photo = [includes_photo[name] for name in names]\\n\",\n        \"  num_helpful = [num_helpful[name] for name in names]\\n\",\n        \"  num_reviews = [[len(ratings)] * len(ratings) for ratings in star_rating]\\n\",\n        \"  labels = [labels[name] for name in names]\\n\",\n        \"\\n\",\n        \"  # Flatten num_reviews\\n\",\n        \"  flattened_features['num_reviews'] = [len(reviews) for reviews in num_reviews]\\n\",\n        \"\\n\",\n        \"  # Convert data into ragged tensors.\\n\",\n        \"  star_rating = tf.ragged.constant(star_rating)\\n\",\n        \"  word_count = tf.ragged.constant(word_count)\\n\",\n        \"  is_amazon = tf.ragged.constant(is_amazon)\\n\",\n        \"  includes_photo = tf.ragged.constant(includes_photo)\\n\",\n        \"  num_helpful = tf.ragged.constant(num_helpful)\\n\",\n        \"  num_reviews = tf.ragged.constant(num_reviews)\\n\",\n        \"  labels = tf.constant(labels)\\n\",\n        \"\\n\",\n        \"  # Now we can return our extracted 
data.\\n\",\n        \"  return (star_rating, word_count, is_amazon, includes_photo, num_helpful,\\n\",\n        \"          num_reviews), labels, flattened_features\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Nd6j_J5CbNiz\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"train_xs, train_ys, flattened_features = extract_features(train_dataframe, 'Sales12-18MonthsAgo')\\n\",\n        \"test_xs, test_ys, _ = extract_features(test_dataframe, 'SalesLastSixMonths')\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"KfHHhCRsHejl\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Let's define our label minimum and maximum.\\n\",\n        \"min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))\\n\",\n        \"min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"9TwqlRirIhAq\"\n      },\n      \"source\": [\n        \"Setting the default values used for training in this guide:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"GckmXFzRIhdD\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"LEARNING_RATE = 0.1\\n\",\n        \"BATCH_SIZE = 128\\n\",\n        \"NUM_EPOCHS = 500\\n\",\n        \"MIDDLE_DIM = 3\\n\",\n        \"MIDDLE_LATTICE_SIZE = 2\\n\",\n        \"MIDDLE_KEYPOINTS = 16\\n\",\n        \"OUTPUT_KEYPOINTS = 8\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"TpDKon4oIh2W\"\n      },\n      \"source\": [\n        \"## Feature Configs\\n\",\n        \"\\n\",\n        \"Feature calibration and per-feature configurations are set using [tfl.configs.FeatureConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/FeatureConfig). Feature configurations include monotonicity constraints, per-feature regularization (see [tfl.configs.RegularizerConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/RegularizerConfig)), and lattice sizes for lattice models.\\n\",\n        \"\\n\",\n        \"Note that we must fully specify the feature config for any feature that we want our model to recognize. Otherwise the model will have no way of knowing that such a feature exists. For aggregation models, these features will automaticaly be considered and properly handled as ragged.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"_IMwcDh7Xs5n\"\n      },\n      \"source\": [\n        \"### Compute Quantiles\\n\",\n        \"\\n\",\n        \"Although the default setting for `pwl_calibration_input_keypoints` in `tfl.configs.FeatureConfig` is 'quantiles', for premade models we have to manually define the input keypoints. 
To do so, we first define our own helper function for computing quantiles.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"l0uYl9ZpXtW1\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def compute_quantiles(features,\\n\",\n        \"                      num_keypoints=10,\\n\",\n        \"                      clip_min=None,\\n\",\n        \"                      clip_max=None,\\n\",\n        \"                      missing_value=None):\\n\",\n        \"  # Clip min and max if desired.\\n\",\n        \"  if clip_min is not None:\\n\",\n        \"    features = np.maximum(features, clip_min)\\n\",\n        \"    features = np.append(features, clip_min)\\n\",\n        \"  if clip_max is not None:\\n\",\n        \"    features = np.minimum(features, clip_max)\\n\",\n        \"    features = np.append(features, clip_max)\\n\",\n        \"  # Make features unique.\\n\",\n        \"  unique_features = np.unique(features)\\n\",\n        \"  # Remove missing values if specified.\\n\",\n        \"  if missing_value is not None:\\n\",\n        \"    unique_features = np.delete(unique_features,\\n\",\n        \"                                np.where(unique_features == missing_value))\\n\",\n        \"  # Compute and return quantiles over unique non-missing feature values.\\n\",\n        \"  return np.quantile(\\n\",\n        \"      unique_features,\\n\",\n        \"      np.linspace(0., 1., num=num_keypoints),\\n\",\n        \"      interpolation='nearest').astype(float)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"9oYZdVeWEhf2\"\n      },\n      \"source\": [\n        \"### Defining Our Feature Configs\\n\",\n        \"\\n\",\n        \"Now that we can compute our quantiles, we define a feature config for each feature that we want our model to take as input.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"rEYlSXhTEmoh\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Feature configs are used to specify how each feature is calibrated and used.\\n\",\n        \"feature_configs = [\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='star_rating',\\n\",\n        \"        lattice_size=2,\\n\",\n        \"        monotonicity='increasing',\\n\",\n        \"        pwl_calibration_num_keypoints=5,\\n\",\n        \"        pwl_calibration_input_keypoints=compute_quantiles(\\n\",\n        \"            flattened_features['star_rating'], num_keypoints=5),\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='word_count',\\n\",\n        \"        lattice_size=2,\\n\",\n        \"        monotonicity='increasing',\\n\",\n        \"        pwl_calibration_num_keypoints=5,\\n\",\n        \"        pwl_calibration_input_keypoints=compute_quantiles(\\n\",\n        \"            flattened_features['word_count'], num_keypoints=5),\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='is_amazon',\\n\",\n        \"        lattice_size=2,\\n\",\n        \"        num_buckets=2,\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='includes_photo',\\n\",\n        \"        lattice_size=2,\\n\",\n        \"        num_buckets=2,\\n\",\n        \"    ),\\n\",\n        \"    
tfl.configs.FeatureConfig(\\n\",\n        \"        name='num_helpful',\\n\",\n        \"        lattice_size=2,\\n\",\n        \"        monotonicity='increasing',\\n\",\n        \"        pwl_calibration_num_keypoints=5,\\n\",\n        \"        pwl_calibration_input_keypoints=compute_quantiles(\\n\",\n        \"            flattened_features['num_helpful'], num_keypoints=5),\\n\",\n        \"        # Larger num_helpful indicating more trust in star_rating.\\n\",\n        \"        reflects_trust_in=[\\n\",\n        \"            tfl.configs.TrustConfig(\\n\",\n        \"                feature_name=\\\"star_rating\\\", trust_type=\\\"trapezoid\\\"),\\n\",\n        \"        ],\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='num_reviews',\\n\",\n        \"        lattice_size=2,\\n\",\n        \"        monotonicity='increasing',\\n\",\n        \"        pwl_calibration_num_keypoints=5,\\n\",\n        \"        pwl_calibration_input_keypoints=compute_quantiles(\\n\",\n        \"            flattened_features['num_reviews'], num_keypoints=5),\\n\",\n        \"    )\\n\",\n        \"]\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"9zoPJRBvPdcH\"\n      },\n      \"source\": [\n        \"## Aggregate Function Model\\n\",\n        \"\\n\",\n        \"To construct a TFL premade model, first construct a model configuration from [tfl.configs](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs). An aggregate function model is constructed using the [tfl.configs.AggregateFunctionConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/AggregateFunctionConfig). It applies piecewise-linear and categorical calibration, followed by a lattice model on each dimension of the ragged input. It then applies an aggregation layer over the output for each dimension. 
This is then followed by an optional output piecewise-linear calibration.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"l_4J7EjSPiP3\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Model config defines the model structure for the aggregate function model.\\n\",\n        \"aggregate_function_model_config = tfl.configs.AggregateFunctionConfig(\\n\",\n        \"    feature_configs=feature_configs,\\n\",\n        \"    middle_dimension=MIDDLE_DIM,\\n\",\n        \"    middle_lattice_size=MIDDLE_LATTICE_SIZE,\\n\",\n        \"    middle_calibration=True,\\n\",\n        \"    middle_calibration_num_keypoints=MIDDLE_KEYPOINTS,\\n\",\n        \"    middle_monotonicity='increasing',\\n\",\n        \"    output_min=min_label,\\n\",\n        \"    output_max=max_label,\\n\",\n        \"    output_calibration=True,\\n\",\n        \"    output_calibration_num_keypoints=OUTPUT_KEYPOINTS,\\n\",\n        \"    output_initialization=np.linspace(\\n\",\n        \"        min_label, max_label, num=OUTPUT_KEYPOINTS))\\n\",\n        \"# An AggregateFunction premade model constructed from the given model config.\\n\",\n        \"aggregate_function_model = tfl.premade.AggregateFunction(\\n\",\n        \"    aggregate_function_model_config)\\n\",\n        \"# Let's plot our model.\\n\",\n        \"keras.utils.plot_model(\\n\",\n        \"    aggregate_function_model, show_layer_names=False, rankdir='LR')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"4F7AwiXgWhe2\"\n      },\n      \"source\": [\n        \"The output of each Aggregation layer is the averaged output of a calibrated lattice over the ragged inputs. 
Here is the model used inside the first Aggregation layer:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"UM7XF6UIWo4T\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"aggregation_layers = [\\n\",\n        \"    layer for layer in aggregate_function_model.layers\\n\",\n        \"    if isinstance(layer, tfl.layers.Aggregation)\\n\",\n        \"]\\n\",\n        \"keras.utils.plot_model(\\n\",\n        \"    aggregation_layers[0].model, show_layer_names=False, rankdir='LR')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"0ohYOftgTZhq\"\n      },\n      \"source\": [\n        \"Now, as with any other [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model), we compile and fit the model to our data.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"uB9di3-lTfMy\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"aggregate_function_model.compile(\\n\",\n        \"    loss='mae',\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE))\\n\",\n        \"aggregate_function_model.fit(\\n\",\n        \"    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"pwZtGDR-Tzur\"\n      },\n      \"source\": [\n        \"After training our model, we can evaluate it on our test set.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"RWj1YfubT0NE\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"print('Test Set Evaluation...')\\n\",\n        \"print(aggregate_function_model.evaluate(test_xs, test_ys))\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"name\": \"aggregate_function_models.ipynb\",\n      \"private_outputs\": true,\n      \"provenance\": [\n        {\n          \"file_id\": \"1ohMV9lhzSWZq3aH27fBAZ1Oj3wy19PI0\",\n          \"timestamp\": 1588637142053\n        }\n      ],\n      \"toc_visible\": true\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "docs/tutorials/keras_layers.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"7765UFHoyGx6\"\n      },\n      \"source\": [\n        \"##### Copyright 2020 The TensorFlow Authors.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"cellView\": \"form\",\n        \"id\": \"KsOkK8O69PyT\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@title Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n\",\n        \"# you may not use this file except in compliance with the License.\\n\",\n        \"# You may obtain a copy of the License at\\n\",\n        \"#\\n\",\n        \"# https://www.apache.org/licenses/LICENSE-2.0\\n\",\n        \"#\\n\",\n        \"# Unless required by applicable law or agreed to in writing, software\\n\",\n        \"# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n\",\n        \"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n\",\n        \"# See the License for the specific language governing permissions and\\n\",\n        \"# limitations under the License.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ZS8z-_KeywY9\"\n      },\n      \"source\": [\n        \"# Creating Keras Models with TFL Layers\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"r61fkA2i9Y3_\"\n      },\n      \"source\": [\n        \"\\u003ctable class=\\\"tfo-notebook-buttons\\\" align=\\\"left\\\"\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://www.tensorflow.org/lattice/tutorials/keras_layers\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/tf_logo_32px.png\\\" /\\u003eView on TensorFlow.org\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/keras_layers.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/colab_logo_32px.png\\\" /\\u003eRun in Google Colab\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/keras_layers.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\\\" /\\u003eView source on GitHub\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca href=\\\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/keras_layers.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/download_logo_32px.png\\\" /\\u003eDownload notebook\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"\\u003c/table\\u003e\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ecLbJCvJSSCd\"\n      },\n      \"source\": [\n        \"##Overview\\n\",\n        \"\\n\",\n        \"You can use TFL Keras layers to construct Keras models with monotonicity and other shape constraints. 
This example builds and trains a calibrated lattice model for the UCI heart dataset using TFL layers.\\n\",\n        \"\\n\",\n        \"In a calibrated lattice model, each feature is transformed by a `tfl.layers.PWLCalibration` or a `tfl.layers.CategoricalCalibration` layer and the results are nonlinearly fused using a `tfl.layers.Lattice`.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"x769lI12IZXB\"\n      },\n      \"source\": [\n        \"## Setup\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fbBVAR6UeRN5\"\n      },\n      \"source\": [\n        \"Installing TF Lattice package:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"bpXjJKpSd3j4\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@test {\\\"skip\\\": true}\\n\",\n        \"!pip install -U tensorflow tf-keras tensorflow-lattice  pydot graphviz\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"jSVl9SHTeSGX\"\n      },\n      \"source\": [\n        \"Importing required packages:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"cellView\": \"both\",\n        \"id\": \"pm0LD8iyIZXF\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import tensorflow as tf\\n\",\n        \"\\n\",\n        \"import logging\\n\",\n        \"import numpy as np\\n\",\n        \"import pandas as pd\\n\",\n        \"import sys\\n\",\n        \"import tensorflow_lattice as tfl\\n\",\n        \"from tensorflow import feature_column as fc\\n\",\n        \"logging.disable(sys.maxsize)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"m8TsvLIe4Az-\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Use Keras 2.\\n\",\n        \"version_fn = getattr(tf.keras, \\\"version\\\", None)\\n\",\n        \"if version_fn and version_fn().startswith(\\\"3.\\\"):\\n\",\n        \"  import tf_keras as keras\\n\",\n        \"else:\\n\",\n        \"  keras = tf.keras\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"svPuM6QNxlrH\"\n      },\n      \"source\": [\n        \"Downloading the UCI Statlog (Heart) dataset:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"cellView\": \"both\",\n        \"id\": \"PG3pFtK-IZXM\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# UCI Statlog (Heart) dataset.\\n\",\n        \"csv_file = keras.utils.get_file(\\n\",\n        \"    'heart.csv', 'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv')\\n\",\n        \"training_data_df = pd.read_csv(csv_file).sample(\\n\",\n        \"    frac=1.0, random_state=41).reset_index(drop=True)\\n\",\n        \"training_data_df.head()\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"nKkAw12SxvGG\"\n      },\n      \"source\": [\n        \"Setting the default values used for training in this guide:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"cellView\": \"both\",\n        \"id\": \"krAJBE-yIZXR\"\n      },\n      \"outputs\": [],\n      
\"source\": [\n        \"LEARNING_RATE = 0.1\\n\",\n        \"BATCH_SIZE = 128\\n\",\n        \"NUM_EPOCHS = 100\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"0TGfzhPHzpix\"\n      },\n      \"source\": [\n        \"## Sequential Keras Model\\n\",\n        \"\\n\",\n        \"This example creates a Sequential Keras model and only uses TFL layers.\\n\",\n        \"\\n\",\n        \"Lattice layers expect `input[i]` to be within `[0, lattice_sizes[i] - 1.0]`, so we need to define the lattice sizes ahead of the calibration layers so we can properly specify output range of the calibration layers.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"nOQWqPAbQS3o\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Lattice layer expects input[i] to be within [0, lattice_sizes[i] - 1.0], so\\n\",\n        \"lattice_sizes = [3, 2, 2, 2, 2, 2, 2]\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"W3DnEKWvQYXm\"\n      },\n      \"source\": [\n        \"We use a `tfl.layers.ParallelCombination` layer to group together calibration layers which have to be executed in parallel in order to be able to create a Sequential model.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"o_hyk5GkQfl8\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"combined_calibrators = tfl.layers.ParallelCombination()\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"BPZsSUZiQiwc\"\n      },\n      \"source\": [\n        \"We create a calibration layer for each feature and add it to the parallel combination layer. For numeric features we use `tfl.layers.PWLCalibration`, and for categorical features we use `tfl.layers.CategoricalCalibration`.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"DXPc6rSGxzFZ\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# ############### age ###############\\n\",\n        \"calibrator = tfl.layers.PWLCalibration(\\n\",\n        \"    # Every PWLCalibration layer must have keypoints of piecewise linear\\n\",\n        \"    # function specified. 
Easiest way to specify them is to uniformly cover\\n\",\n        \"    # entire input range by using numpy.linspace().\\n\",\n        \"    input_keypoints=np.linspace(\\n\",\n        \"        training_data_df['age'].min(), training_data_df['age'].max(), num=5),\\n\",\n        \"    # You need to ensure that input keypoints have same dtype as layer input.\\n\",\n        \"    # You can do it by setting dtype here or by providing keypoints in such\\n\",\n        \"    # format which will be converted to desired tf.dtype by default.\\n\",\n        \"    dtype=tf.float32,\\n\",\n        \"    # Output range must correspond to expected lattice input range.\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[0] - 1.0,\\n\",\n        \")\\n\",\n        \"combined_calibrators.append(calibrator)\\n\",\n        \"\\n\",\n        \"# ############### sex ###############\\n\",\n        \"# For boolean features simply specify CategoricalCalibration layer with 2\\n\",\n        \"# buckets.\\n\",\n        \"calibrator = tfl.layers.CategoricalCalibration(\\n\",\n        \"    num_buckets=2,\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[1] - 1.0,\\n\",\n        \"    # Initializes all outputs to (output_min + output_max) / 2.0.\\n\",\n        \"    kernel_initializer='constant')\\n\",\n        \"combined_calibrators.append(calibrator)\\n\",\n        \"\\n\",\n        \"# ############### cp ###############\\n\",\n        \"calibrator = tfl.layers.PWLCalibration(\\n\",\n        \"    # Here instead of specifying dtype of layer we convert keypoints into\\n\",\n        \"    # np.float32.\\n\",\n        \"    input_keypoints=np.linspace(1, 4, num=4, dtype=np.float32),\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[2] - 1.0,\\n\",\n        \"    monotonicity='increasing',\\n\",\n        \"    # You can specify TFL regularizers as a tuple ('regularizer name', l1, l2).\\n\",\n        \"    kernel_regularizer=('hessian', 0.0, 1e-4))\\n\",\n        \"combined_calibrators.append(calibrator)\\n\",\n        \"\\n\",\n        \"# ############### trestbps ###############\\n\",\n        \"calibrator = tfl.layers.PWLCalibration(\\n\",\n        \"    # Alternatively, you might want to use quantiles as keypoints instead of\\n\",\n        \"    # uniform keypoints\\n\",\n        \"    input_keypoints=np.quantile(training_data_df['trestbps'],\\n\",\n        \"                                np.linspace(0.0, 1.0, num=5)),\\n\",\n        \"    dtype=tf.float32,\\n\",\n        \"    # Together with quantile keypoints you might want to initialize piecewise\\n\",\n        \"    # linear function to have 'equal_slopes' in order for output of layer\\n\",\n        \"    # after initialization to preserve original distribution.\\n\",\n        \"    kernel_initializer='equal_slopes',\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[3] - 1.0,\\n\",\n        \"    # You might consider clamping extreme inputs of the calibrator to output\\n\",\n        \"    # bounds.\\n\",\n        \"    clamp_min=True,\\n\",\n        \"    clamp_max=True,\\n\",\n        \"    monotonicity='increasing')\\n\",\n        \"combined_calibrators.append(calibrator)\\n\",\n        \"\\n\",\n        \"# ############### chol ###############\\n\",\n        \"calibrator = tfl.layers.PWLCalibration(\\n\",\n        \"    # Explicit input keypoint initialization.\\n\",\n        \"    input_keypoints=[126.0, 210.0, 247.0, 286.0, 
564.0],\\n\",\n        \"    dtype=tf.float32,\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[4] - 1.0,\\n\",\n        \"    # Monotonicity of calibrator can be decreasing. Note that corresponding\\n\",\n        \"    # lattice dimension must have INCREASING monotonicity regardless of\\n\",\n        \"    # monotonicity direction of calibrator.\\n\",\n        \"    monotonicity='decreasing',\\n\",\n        \"    # Convexity together with decreasing monotonicity result in diminishing\\n\",\n        \"    # return constraint.\\n\",\n        \"    convexity='convex',\\n\",\n        \"    # You can specify list of regularizers. You are not limited to TFL\\n\",\n        \"    # regularizrs. Feel free to use any :)\\n\",\n        \"    kernel_regularizer=[('laplacian', 0.0, 1e-4),\\n\",\n        \"                        keras.regularizers.l1_l2(l1=0.001)])\\n\",\n        \"combined_calibrators.append(calibrator)\\n\",\n        \"\\n\",\n        \"# ############### fbs ###############\\n\",\n        \"calibrator = tfl.layers.CategoricalCalibration(\\n\",\n        \"    num_buckets=2,\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[5] - 1.0,\\n\",\n        \"    # For categorical calibration layer monotonicity is specified for pairs\\n\",\n        \"    # of indices of categories. Output for first category in pair will be\\n\",\n        \"    # smaller than output for second category.\\n\",\n        \"    #\\n\",\n        \"    # Don't forget to set monotonicity of corresponding dimension of Lattice\\n\",\n        \"    # layer to '1'.\\n\",\n        \"    monotonicities=[(0, 1)],\\n\",\n        \"    # This initializer is identical to default one('uniform'), but has fixed\\n\",\n        \"    # seed in order to simplify experimentation.\\n\",\n        \"    kernel_initializer=keras.initializers.RandomUniform(\\n\",\n        \"        minval=0.0, maxval=lattice_sizes[5] - 1.0, seed=1))\\n\",\n        \"combined_calibrators.append(calibrator)\\n\",\n        \"\\n\",\n        \"# ############### restecg ###############\\n\",\n        \"calibrator = tfl.layers.CategoricalCalibration(\\n\",\n        \"    num_buckets=3,\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[6] - 1.0,\\n\",\n        \"    # Categorical monotonicity can be partial order.\\n\",\n        \"    monotonicities=[(0, 1), (0, 2)],\\n\",\n        \"    # Categorical calibration layer supports standard Keras regularizers.\\n\",\n        \"    kernel_regularizer=keras.regularizers.l1_l2(l1=0.001),\\n\",\n        \"    kernel_initializer='constant')\\n\",\n        \"combined_calibrators.append(calibrator)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"inyNlSBeQyp7\"\n      },\n      \"source\": [\n        \"We then create a lattice layer to nonlinearly fuse the outputs of the calibrators.\\n\",\n        \"\\n\",\n        \"Note that we need to specify the monotonicity of the lattice to be increasing for required dimensions. The composition with the direction of the monotonicity in the calibration will result in the correct end-to-end direction of monotonicity. 
This includes partial monotonicity of CategoricalCalibration layer.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"DNCc9oBTRo6w\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"lattice = tfl.layers.Lattice(\\n\",\n        \"    lattice_sizes=lattice_sizes,\\n\",\n        \"    monotonicities=[\\n\",\n        \"        'increasing', 'none', 'increasing', 'increasing', 'increasing',\\n\",\n        \"        'increasing', 'increasing'\\n\",\n        \"    ],\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=1.0)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"T5q2InayRpDr\"\n      },\n      \"source\": [\n        \"We can then create a sequential model using the combined calibrators and lattice layers.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"xX6lroYZQy3L\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"model = keras.models.Sequential()\\n\",\n        \"model.add(combined_calibrators)\\n\",\n        \"model.add(lattice)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"W3UFxD3fRzIC\"\n      },\n      \"source\": [\n        \"Training works the same as any other keras model.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"2jz4JvI-RzSj\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"features = training_data_df[[\\n\",\n        \"    'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg'\\n\",\n        \"]].values.astype(np.float32)\\n\",\n        \"target = training_data_df[['target']].values.astype(np.float32)\\n\",\n        \"\\n\",\n        \"model.compile(\\n\",\n        \"    loss=keras.losses.mean_squared_error,\\n\",\n        \"    optimizer=keras.optimizers.Adagrad(learning_rate=LEARNING_RATE))\\n\",\n        \"model.fit(\\n\",\n        \"    features,\\n\",\n        \"    target,\\n\",\n        \"    batch_size=BATCH_SIZE,\\n\",\n        \"    epochs=NUM_EPOCHS,\\n\",\n        \"    validation_split=0.2,\\n\",\n        \"    shuffle=False,\\n\",\n        \"    verbose=0)\\n\",\n        \"\\n\",\n        \"model.evaluate(features, target)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"RTHoW_5lxwT5\"\n      },\n      \"source\": [\n        \"## Functional Keras Model\\n\",\n        \"\\n\",\n        \"This example uses a functional API for Keras model construction.\\n\",\n        \"\\n\",\n        \"As mentioned in the previous section, lattice layers expect `input[i]` to be within `[0, lattice_sizes[i] - 1.0]`, so we need to define the lattice sizes ahead of the calibration layers so we can properly specify output range of the calibration layers.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"gJjUYvBuW1qE\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# We are going to have 2-d embedding as one of lattice inputs.\\n\",\n        \"lattice_sizes = [3, 2, 2, 3, 3, 2, 2]\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Z03qY5MYW1yT\"\n      },\n      \"source\": [\n        \"For each feature, we need to create an input layer 
followed by a calibration layer. For numeric features we use `tfl.layers.PWLCalibration` and for categorical features we use `tfl.layers.CategoricalCalibration`.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"DCIUz8apzs0l\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"model_inputs = []\\n\",\n        \"lattice_inputs = []\\n\",\n        \"# ############### age ###############\\n\",\n        \"age_input = keras.layers.Input(shape=[1], name='age')\\n\",\n        \"model_inputs.append(age_input)\\n\",\n        \"age_calibrator = tfl.layers.PWLCalibration(\\n\",\n        \"    # Every PWLCalibration layer must have keypoints of piecewise linear\\n\",\n        \"    # function specified. Easiest way to specify them is to uniformly cover\\n\",\n        \"    # entire input range by using numpy.linspace().\\n\",\n        \"    input_keypoints=np.linspace(\\n\",\n        \"        training_data_df['age'].min(), training_data_df['age'].max(), num=5),\\n\",\n        \"    # You need to ensure that input keypoints have same dtype as layer input.\\n\",\n        \"    # You can do it by setting dtype here or by providing keypoints in such\\n\",\n        \"    # format which will be converted to desired tf.dtype by default.\\n\",\n        \"    dtype=tf.float32,\\n\",\n        \"    # Output range must correspond to expected lattice input range.\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[0] - 1.0,\\n\",\n        \"    monotonicity='increasing',\\n\",\n        \"    name='age_calib',\\n\",\n        \")(\\n\",\n        \"    age_input)\\n\",\n        \"lattice_inputs.append(age_calibrator)\\n\",\n        \"\\n\",\n        \"# ############### sex ###############\\n\",\n        \"# For boolean features simply specify CategoricalCalibration layer with 2\\n\",\n        \"# buckets.\\n\",\n        \"sex_input = keras.layers.Input(shape=[1], name='sex')\\n\",\n        \"model_inputs.append(sex_input)\\n\",\n        \"sex_calibrator = tfl.layers.CategoricalCalibration(\\n\",\n        \"    num_buckets=2,\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[1] - 1.0,\\n\",\n        \"    # Initializes all outputs to (output_min + output_max) / 2.0.\\n\",\n        \"    kernel_initializer='constant',\\n\",\n        \"    name='sex_calib',\\n\",\n        \")(\\n\",\n        \"    sex_input)\\n\",\n        \"lattice_inputs.append(sex_calibrator)\\n\",\n        \"\\n\",\n        \"# ############### cp ###############\\n\",\n        \"cp_input = keras.layers.Input(shape=[1], name='cp')\\n\",\n        \"model_inputs.append(cp_input)\\n\",\n        \"cp_calibrator = tfl.layers.PWLCalibration(\\n\",\n        \"    # Here instead of specifying dtype of layer we convert keypoints into\\n\",\n        \"    # np.float32.\\n\",\n        \"    input_keypoints=np.linspace(1, 4, num=4, dtype=np.float32),\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[2] - 1.0,\\n\",\n        \"    monotonicity='increasing',\\n\",\n        \"    # You can specify TFL regularizers as tuple ('regularizer name', l1, l2).\\n\",\n        \"    kernel_regularizer=('hessian', 0.0, 1e-4),\\n\",\n        \"    name='cp_calib',\\n\",\n        \")(\\n\",\n        \"    cp_input)\\n\",\n        \"lattice_inputs.append(cp_calibrator)\\n\",\n        \"\\n\",\n        \"# ############### trestbps ###############\\n\",\n        \"trestbps_input = 
keras.layers.Input(shape=[1], name='trestbps')\\n\",\n        \"model_inputs.append(trestbps_input)\\n\",\n        \"trestbps_calibrator = tfl.layers.PWLCalibration(\\n\",\n        \"    # Alternatively, you might want to use quantiles as keypoints instead of\\n\",\n        \"    # uniform keypoints\\n\",\n        \"    input_keypoints=np.quantile(training_data_df['trestbps'],\\n\",\n        \"                                np.linspace(0.0, 1.0, num=5)),\\n\",\n        \"    dtype=tf.float32,\\n\",\n        \"    # Together with quantile keypoints you might want to initialize piecewise\\n\",\n        \"    # linear function to have 'equal_slopes' in order for output of layer\\n\",\n        \"    # after initialization to preserve original distribution.\\n\",\n        \"    kernel_initializer='equal_slopes',\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[3] - 1.0,\\n\",\n        \"    # You might consider clamping extreme inputs of the calibrator to output\\n\",\n        \"    # bounds.\\n\",\n        \"    clamp_min=True,\\n\",\n        \"    clamp_max=True,\\n\",\n        \"    monotonicity='increasing',\\n\",\n        \"    name='trestbps_calib',\\n\",\n        \")(\\n\",\n        \"    trestbps_input)\\n\",\n        \"lattice_inputs.append(trestbps_calibrator)\\n\",\n        \"\\n\",\n        \"# ############### chol ###############\\n\",\n        \"chol_input = keras.layers.Input(shape=[1], name='chol')\\n\",\n        \"model_inputs.append(chol_input)\\n\",\n        \"chol_calibrator = tfl.layers.PWLCalibration(\\n\",\n        \"    # Explicit input keypoint initialization.\\n\",\n        \"    input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[4] - 1.0,\\n\",\n        \"    # Monotonicity of calibrator can be decreasing. Note that corresponding\\n\",\n        \"    # lattice dimension must have INCREASING monotonicity regardless of\\n\",\n        \"    # monotonicity direction of calibrator.\\n\",\n        \"    monotonicity='decreasing',\\n\",\n        \"    # Convexity together with decreasing monotonicity results in a diminishing\\n\",\n        \"    # returns constraint.\\n\",\n        \"    convexity='convex',\\n\",\n        \"    # You can specify a list of regularizers. You are not limited to TFL\\n\",\n        \"    # regularizers. Feel free to use any :)\\n\",\n        \"    kernel_regularizer=[('laplacian', 0.0, 1e-4),\\n\",\n        \"                        keras.regularizers.l1_l2(l1=0.001)],\\n\",\n        \"    name='chol_calib',\\n\",\n        \")(\\n\",\n        \"    chol_input)\\n\",\n        \"lattice_inputs.append(chol_calibrator)\\n\",\n        \"\\n\",\n        \"# ############### fbs ###############\\n\",\n        \"fbs_input = keras.layers.Input(shape=[1], name='fbs')\\n\",\n        \"model_inputs.append(fbs_input)\\n\",\n        \"fbs_calibrator = tfl.layers.CategoricalCalibration(\\n\",\n        \"    num_buckets=2,\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[5] - 1.0,\\n\",\n        \"    # For categorical calibration layer monotonicity is specified for pairs\\n\",\n        \"    # of indices of categories. 
Output for first category in pair will be\\n\",\n        \"    # smaller than output for second category.\\n\",\n        \"    #\\n\",\n        \"    # Don't forget to set monotonicity of corresponding dimension of Lattice\\n\",\n        \"    # layer to '1'.\\n\",\n        \"    monotonicities=[(0, 1)],\\n\",\n        \"    # This initializer is identical to default one ('uniform'), but has fixed\\n\",\n        \"    # seed in order to simplify experimentation.\\n\",\n        \"    kernel_initializer=keras.initializers.RandomUniform(\\n\",\n        \"        minval=0.0, maxval=lattice_sizes[5] - 1.0, seed=1),\\n\",\n        \"    name='fbs_calib',\\n\",\n        \")(\\n\",\n        \"    fbs_input)\\n\",\n        \"lattice_inputs.append(fbs_calibrator)\\n\",\n        \"\\n\",\n        \"# ############### restecg ###############\\n\",\n        \"restecg_input = keras.layers.Input(shape=[1], name='restecg')\\n\",\n        \"model_inputs.append(restecg_input)\\n\",\n        \"restecg_calibrator = tfl.layers.CategoricalCalibration(\\n\",\n        \"    num_buckets=3,\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=lattice_sizes[6] - 1.0,\\n\",\n        \"    # Categorical monotonicity can be partial order.\\n\",\n        \"    monotonicities=[(0, 1), (0, 2)],\\n\",\n        \"    # Categorical calibration layer supports standard Keras regularizers.\\n\",\n        \"    kernel_regularizer=keras.regularizers.l1_l2(l1=0.001),\\n\",\n        \"    kernel_initializer='constant',\\n\",\n        \"    name='restecg_calib',\\n\",\n        \")(\\n\",\n        \"    restecg_input)\\n\",\n        \"lattice_inputs.append(restecg_calibrator)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Fr0k8La_YgQG\"\n      },\n      \"source\": [\n        \"We then create a lattice layer to nonlinearly fuse the outputs of the calibrators.\\n\",\n        \"\\n\",\n        \"Note that we need to specify the monotonicity of the lattice to be increasing for required dimensions. The composition with the direction of the monotonicity in the calibration will result in the correct end-to-end direction of monotonicity. 
This includes partial monotonicity of `tfl.layers.CategoricalCalibration` layer.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"X15RE0NybNbU\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"lattice = tfl.layers.Lattice(\\n\",\n        \"    lattice_sizes=lattice_sizes,\\n\",\n        \"    monotonicities=[\\n\",\n        \"        'increasing', 'none', 'increasing', 'increasing', 'increasing',\\n\",\n        \"        'increasing', 'increasing'\\n\",\n        \"    ],\\n\",\n        \"    output_min=0.0,\\n\",\n        \"    output_max=1.0,\\n\",\n        \"    name='lattice',\\n\",\n        \")(\\n\",\n        \"    lattice_inputs)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"31VzsnMCA9dh\"\n      },\n      \"source\": [\n        \"To add more flexibility to the model, we add an output calibration layer.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"efCP3Yx2A9n7\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"model_output = tfl.layers.PWLCalibration(\\n\",\n        \"    input_keypoints=np.linspace(0.0, 1.0, 5),\\n\",\n        \"    name='output_calib',\\n\",\n        \")(\\n\",\n        \"    lattice)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"1SURnNl8bNgw\"\n      },\n      \"source\": [\n        \"We can now create a model using the inputs and outputs.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"7gY-VXuYbZLa\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"model = keras.models.Model(\\n\",\n        \"    inputs=model_inputs,\\n\",\n        \"    outputs=model_output)\\n\",\n        \"keras.utils.plot_model(model, rankdir='LR')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"tvFJTs94bZXK\"\n      },\n      \"source\": [\n        \"Training works the same as any other keras model. 
Note that, with our setup, input features are passed as separate tensors.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"vMQTGbFAYgYS\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"feature_names = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg']\\n\",\n        \"features = np.split(\\n\",\n        \"    training_data_df[feature_names].values.astype(np.float32),\\n\",\n        \"    indices_or_sections=len(feature_names),\\n\",\n        \"    axis=1)\\n\",\n        \"target = training_data_df[['target']].values.astype(np.float32)\\n\",\n        \"\\n\",\n        \"model.compile(\\n\",\n        \"    loss=keras.losses.mean_squared_error,\\n\",\n        \"    optimizer=keras.optimizers.Adagrad(LEARNING_RATE))\\n\",\n        \"model.fit(\\n\",\n        \"    features,\\n\",\n        \"    target,\\n\",\n        \"    batch_size=BATCH_SIZE,\\n\",\n        \"    epochs=NUM_EPOCHS,\\n\",\n        \"    validation_split=0.2,\\n\",\n        \"    shuffle=False,\\n\",\n        \"    verbose=0)\\n\",\n        \"\\n\",\n        \"model.evaluate(features, target)\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"name\": \"keras_layers.ipynb\",\n      \"private_outputs\": true,\n      \"provenance\": [],\n      \"toc_visible\": true\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "docs/tutorials/premade_models.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"HZiF5lbumA7j\"\n      },\n      \"source\": [\n        \"##### Copyright 2020 The TensorFlow Authors.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"cellView\": \"form\",\n        \"id\": \"KsOkK8O69PyT\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@title Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n\",\n        \"# you may not use this file except in compliance with the License.\\n\",\n        \"# You may obtain a copy of the License at\\n\",\n        \"#\\n\",\n        \"# https://www.apache.org/licenses/LICENSE-2.0\\n\",\n        \"#\\n\",\n        \"# Unless required by applicable law or agreed to in writing, software\\n\",\n        \"# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n\",\n        \"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n\",\n        \"# See the License for the specific language governing permissions and\\n\",\n        \"# limitations under the License.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"eNj0_BTFk479\"\n      },\n      \"source\": [\n        \"# TF Lattice Premade Models\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"T3qE8F5toE28\"\n      },\n      \"source\": [\n        \"\\u003ctable class=\\\"tfo-notebook-buttons\\\" align=\\\"left\\\"\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://www.tensorflow.org/lattice/tutorials/premade_models\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/tf_logo_32px.png\\\" /\\u003eView on TensorFlow.org\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/premade_models.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/colab_logo_32px.png\\\" /\\u003eRun in Google Colab\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/premade_models.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\\\" /\\u003eView source on GitHub\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca href=\\\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/premade_models.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/download_logo_32px.png\\\" /\\u003eDownload notebook\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"\\u003c/table\\u003e\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"HEuRMAUOlFZa\"\n      },\n      \"source\": [\n        \"## Overview\\n\",\n        \"\\n\",\n        \"Premade Models are quick and easy ways to build TFL `keras.Model` instances for typical use cases. 
This guide outlines the steps needed to construct a TFL Premade Model and train/test it.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"f2--Yq21lhRe\"\n      },\n      \"source\": [\n        \"## Setup\\n\",\n        \"\\n\",\n        \"Installing TF Lattice package:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"XizqBCyXky4y\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@test {\\\"skip\\\": true}\\n\",\n        \"!pip install -U tensorflow tf-keras tensorflow-lattice  pydot graphviz\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"2oKJPy5tloOB\"\n      },\n      \"source\": [\n        \"Importing required packages:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"9wZWJJggk4al\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import tensorflow as tf\\n\",\n        \"\\n\",\n        \"import copy\\n\",\n        \"import logging\\n\",\n        \"import numpy as np\\n\",\n        \"import pandas as pd\\n\",\n        \"import sys\\n\",\n        \"import tensorflow_lattice as tfl\\n\",\n        \"logging.disable(sys.maxsize)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"k-AAoRho3x5N\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Use Keras 2.\\n\",\n        \"version_fn = getattr(tf.keras, \\\"version\\\", None)\\n\",\n        \"if version_fn and version_fn().startswith(\\\"3.\\\"):\\n\",\n        \"  import tf_keras as keras\\n\",\n        \"else:\\n\",\n        \"  keras = tf.keras\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"oyOrtol7mW9r\"\n      },\n      \"source\": [\n        \"Setting the default values used for training in this guide:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"ns8pH2AnmgAC\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"LEARNING_RATE = 0.01\\n\",\n        \"BATCH_SIZE = 128\\n\",\n        \"NUM_EPOCHS = 500\\n\",\n        \"PREFITTING_NUM_EPOCHS = 10\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"kpJJSS7YmLbG\"\n      },\n      \"source\": [\n        \"Downloading the UCI Statlog (Heart) dataset:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"AYTcybljmQJm\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"heart_csv_file = keras.utils.get_file(\\n\",\n        \"    'heart.csv',\\n\",\n        \"    'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv')\\n\",\n        \"heart_df = pd.read_csv(heart_csv_file)\\n\",\n        \"thal_vocab_list = ['normal', 'fixed', 'reversible']\\n\",\n        \"heart_df['thal'] = heart_df['thal'].map(\\n\",\n        \"    {v: i for i, v in enumerate(thal_vocab_list)})\\n\",\n        \"heart_df = heart_df.astype(float)\\n\",\n        \"\\n\",\n        \"heart_train_size = int(len(heart_df) * 0.8)\\n\",\n        \"heart_train_dict = dict(heart_df[:heart_train_size])\\n\",\n        \"heart_test_dict = dict(heart_df[heart_train_size:])\\n\",\n        \"\\n\",\n        \"# This 
ordering of input features should match the feature configs. If no\\n\",\n        \"# feature config relies explicitly on the data (i.e. all are 'quantiles'),\\n\",\n        \"# then you can construct the feature_names list by simply iterating over each\\n\",\n        \"# feature config and extracting its name.\\n\",\n        \"feature_names = [\\n\",\n        \"    'age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',\\n\",\n        \"    'exang', 'oldpeak', 'slope', 'ca', 'thal'\\n\",\n        \"]\\n\",\n        \"\\n\",\n        \"# Since we have some features that manually construct their input keypoints,\\n\",\n        \"# we need an index mapping of the feature names.\\n\",\n        \"feature_name_indices = {name: index for index, name in enumerate(feature_names)}\\n\",\n        \"\\n\",\n        \"label_name = 'target'\\n\",\n        \"heart_train_xs = [\\n\",\n        \"    heart_train_dict[feature_name] for feature_name in feature_names\\n\",\n        \"]\\n\",\n        \"heart_test_xs = [heart_test_dict[feature_name] for feature_name in feature_names]\\n\",\n        \"heart_train_ys = heart_train_dict[label_name]\\n\",\n        \"heart_test_ys = heart_test_dict[label_name]\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Ix2elMrGmiWX\"\n      },\n      \"source\": [\n        \"## Feature Configs\\n\",\n        \"\\n\",\n        \"Feature calibration and per-feature configurations are set using [tfl.configs.FeatureConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/FeatureConfig). Feature configurations include monotonicity constraints, per-feature regularization (see [tfl.configs.RegularizerConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/RegularizerConfig)), and lattice sizes for lattice models.\\n\",\n        \"\\n\",\n        \"Note that we must fully specify the feature config for any feature that we want our model to recognize. 
Otherwise the model will have no way of knowing that such a feature exists.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ePWXuDH7-1i1\"\n      },\n      \"source\": [\n        \"### Defining Our Feature Configs\\n\",\n        \"\\n\",\n        \"Now that we can compute our quantiles, we define a feature config for each feature that we want our model to take as input.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"8y27RmHIrSBn\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Features:\\n\",\n        \"# - age\\n\",\n        \"# - sex\\n\",\n        \"# - cp        chest pain type (4 values)\\n\",\n        \"# - trestbps  resting blood pressure\\n\",\n        \"# - chol      serum cholestoral in mg/dl\\n\",\n        \"# - fbs       fasting blood sugar \\u003e 120 mg/dl\\n\",\n        \"# - restecg   resting electrocardiographic results (values 0,1,2)\\n\",\n        \"# - thalach   maximum heart rate achieved\\n\",\n        \"# - exang     exercise induced angina\\n\",\n        \"# - oldpeak   ST depression induced by exercise relative to rest\\n\",\n        \"# - slope     the slope of the peak exercise ST segment\\n\",\n        \"# - ca        number of major vessels (0-3) colored by flourosopy\\n\",\n        \"# - thal      normal; fixed defect; reversable defect\\n\",\n        \"#\\n\",\n        \"# Feature configs are used to specify how each feature is calibrated and used.\\n\",\n        \"heart_feature_configs = [\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='age',\\n\",\n        \"        lattice_size=3,\\n\",\n        \"        monotonicity='increasing',\\n\",\n        \"        # We must set the keypoints manually.\\n\",\n        \"        pwl_calibration_num_keypoints=5,\\n\",\n        \"        pwl_calibration_input_keypoints='quantiles',\\n\",\n        \"        pwl_calibration_clip_max=100,\\n\",\n        \"        # Per feature regularization.\\n\",\n        \"        regularizer_configs=[\\n\",\n        \"            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),\\n\",\n        \"        ],\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='sex',\\n\",\n        \"        num_buckets=2,\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='cp',\\n\",\n        \"        monotonicity='increasing',\\n\",\n        \"        # Keypoints that are uniformly spaced.\\n\",\n        \"        pwl_calibration_num_keypoints=4,\\n\",\n        \"        pwl_calibration_input_keypoints=np.linspace(\\n\",\n        \"            np.min(heart_train_xs[feature_name_indices['cp']]),\\n\",\n        \"            np.max(heart_train_xs[feature_name_indices['cp']]),\\n\",\n        \"            num=4),\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='chol',\\n\",\n        \"        monotonicity='increasing',\\n\",\n        \"        # Explicit input keypoints initialization.\\n\",\n        \"        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\\n\",\n        \"        # Calibration can be forced to span the full output range by clamping.\\n\",\n        \"        pwl_calibration_clamp_min=True,\\n\",\n        \"        pwl_calibration_clamp_max=True,\\n\",\n        \"        # Per feature regularization.\\n\",\n      
  \"        regularizer_configs=[\\n\",\n        \"            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\\n\",\n        \"        ],\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='fbs',\\n\",\n        \"        # Partial monotonicity: output(0) \\u003c= output(1)\\n\",\n        \"        monotonicity=[(0, 1)],\\n\",\n        \"        num_buckets=2,\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='trestbps',\\n\",\n        \"        monotonicity='decreasing',\\n\",\n        \"        pwl_calibration_num_keypoints=5,\\n\",\n        \"        pwl_calibration_input_keypoints='quantiles',\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='thalach',\\n\",\n        \"        monotonicity='decreasing',\\n\",\n        \"        pwl_calibration_num_keypoints=5,\\n\",\n        \"        pwl_calibration_input_keypoints='quantiles',\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='restecg',\\n\",\n        \"        # Partial monotonicity: output(0) \\u003c= output(1), output(0) \\u003c= output(2)\\n\",\n        \"        monotonicity=[(0, 1), (0, 2)],\\n\",\n        \"        num_buckets=3,\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='exang',\\n\",\n        \"        # Partial monotonicity: output(0) \\u003c= output(1)\\n\",\n        \"        monotonicity=[(0, 1)],\\n\",\n        \"        num_buckets=2,\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='oldpeak',\\n\",\n        \"        monotonicity='increasing',\\n\",\n        \"        pwl_calibration_num_keypoints=5,\\n\",\n        \"        pwl_calibration_input_keypoints='quantiles',\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='slope',\\n\",\n        \"        # Partial monotonicity: output(0) \\u003c= output(1), output(1) \\u003c= output(2)\\n\",\n        \"        monotonicity=[(0, 1), (1, 2)],\\n\",\n        \"        num_buckets=3,\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='ca',\\n\",\n        \"        monotonicity='increasing',\\n\",\n        \"        pwl_calibration_num_keypoints=4,\\n\",\n        \"        pwl_calibration_input_keypoints='quantiles',\\n\",\n        \"    ),\\n\",\n        \"    tfl.configs.FeatureConfig(\\n\",\n        \"        name='thal',\\n\",\n        \"        # Partial monotonicity:\\n\",\n        \"        # output(normal) \\u003c= output(fixed)\\n\",\n        \"        # output(normal) \\u003c= output(reversible)\\n\",\n        \"        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],\\n\",\n        \"        num_buckets=3,\\n\",\n        \"        # We must specify the vocabulary list in order to later set the\\n\",\n        \"        # monotonicities since we used names and not indices.\\n\",\n        \"        vocabulary_list=thal_vocab_list,\\n\",\n        \"    ),\\n\",\n        \"]\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"-XuAnP_-vyK6\"\n      },\n      \"source\": [\n        \"## Set Monotonicities and Keypoints\\n\",\n        \"\\n\",\n        \"Next we need to make sure to properly set the monotonicities for features where we used a custom vocabulary (such as 'thal' 
above).\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"ZIn2-EfGv--m\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"tfl.premade_lib.set_categorical_monotonicities(heart_feature_configs)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fHyzh5YHyD5n\"\n      },\n      \"source\": [\n        \"Finally we can complete our feature configs by calculating and setting the keypoints.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"KJ5kKd-lyJhZ\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"feature_keypoints = tfl.premade_lib.compute_feature_keypoints(\\n\",\n        \"    feature_configs=heart_feature_configs, features=heart_train_dict)\\n\",\n        \"tfl.premade_lib.set_feature_keypoints(\\n\",\n        \"    feature_configs=heart_feature_configs,\\n\",\n        \"    feature_keypoints=feature_keypoints,\\n\",\n        \"    add_missing_feature_configs=False)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Mx50YgWMcxC4\"\n      },\n      \"source\": [\n        \"## Calibrated Linear Model\\n\",\n        \"\\n\",\n        \"To construct a TFL premade model, first construct a model configuration from [tfl.configs](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs). A calibrated linear model is constructed using the [tfl.configs.CalibratedLinearConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/CalibratedLinearConfig). It applies piecewise-linear and categorical calibration on the input features, followed by a linear combination and an optional output piecewise-linear calibration. 
When using output calibration or when output bounds are specified, the linear layer will apply weighted averaging on calibrated inputs.\\n\",\n        \"\\n\",\n        \"This example creates a calibrated linear model on the first 5 features.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"UvMDJKqTc1vC\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Model config defines the model structure for the premade model.\\n\",\n        \"linear_model_config = tfl.configs.CalibratedLinearConfig(\\n\",\n        \"    feature_configs=heart_feature_configs[:5],\\n\",\n        \"    use_bias=True,\\n\",\n        \"    output_calibration=True,\\n\",\n        \"    output_calibration_num_keypoints=10,\\n\",\n        \"    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\\n\",\n        \"    output_initialization=np.linspace(-2.0, 2.0, num=10),\\n\",\n        \"    regularizer_configs=[\\n\",\n        \"        # Regularizer for the output calibrator.\\n\",\n        \"        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\\n\",\n        \"    ])\\n\",\n        \"# A CalibratedLinear premade model constructed from the given model config.\\n\",\n        \"linear_model = tfl.premade.CalibratedLinear(linear_model_config)\\n\",\n        \"# Let's plot our model.\\n\",\n        \"keras.utils.plot_model(linear_model, show_layer_names=False, rankdir='LR')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"3MC3-AyX00-A\"\n      },\n      \"source\": [\n        \"Now, as with any other [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model), we compile and fit the model to our data.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"hPlEK-yG1B-U\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"linear_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True)],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE))\\n\",\n        \"linear_model.fit(\\n\",\n        \"    heart_train_xs[:5],\\n\",\n        \"    heart_train_ys,\\n\",\n        \"    epochs=NUM_EPOCHS,\\n\",\n        \"    batch_size=BATCH_SIZE,\\n\",\n        \"    verbose=False)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"OG2ua0MGAkoi\"\n      },\n      \"source\": [\n        \"After training our model, we can evaluate it on our test set.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"HybGTvXxAoxV\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"print('Test Set Evaluation...')\\n\",\n        \"print(linear_model.evaluate(heart_test_xs[:5], heart_test_ys))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"jAAJK-wlc15S\"\n      },\n      \"source\": [\n        \"## Calibrated Lattice Model\\n\",\n        \"\\n\",\n        \"A calibrated lattice model is constructed using [tfl.configs.CalibratedLatticeConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/CalibratedLatticeConfig). 
A calibrated lattice model applies piecewise-linear and categorical calibration on the input features, followed by a lattice model and an optional output piecewise-linear calibration.\\n\",\n        \"\\n\",\n        \"This example creates a calibrated lattice model on the first 5 features.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"u7gNcrMtc4Lp\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# This is a calibrated lattice model: inputs are calibrated, then combined\\n\",\n        \"# non-linearly using a lattice layer.\\n\",\n        \"lattice_model_config = tfl.configs.CalibratedLatticeConfig(\\n\",\n        \"    feature_configs=heart_feature_configs[:5],\\n\",\n        \"    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\\n\",\n        \"    output_initialization=[-2.0, 2.0],\\n\",\n        \"    regularizer_configs=[\\n\",\n        \"        # Torsion regularizer applied to the lattice to make it more linear.\\n\",\n        \"        tfl.configs.RegularizerConfig(name='torsion', l2=1e-2),\\n\",\n        \"        # Globally defined calibration regularizer is applied to all features.\\n\",\n        \"        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-2),\\n\",\n        \"    ])\\n\",\n        \"# A CalibratedLattice premade model constructed from the given model config.\\n\",\n        \"lattice_model = tfl.premade.CalibratedLattice(lattice_model_config)\\n\",\n        \"# Let's plot our model.\\n\",\n        \"keras.utils.plot_model(lattice_model, show_layer_names=False, rankdir='LR')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"nmc3TUIIGGoH\"\n      },\n      \"source\": [\n        \"As before, we compile, fit, and evaluate our model.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"vIjOQGD2Gp_Z\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"lattice_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True)],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE))\\n\",\n        \"lattice_model.fit(\\n\",\n        \"    heart_train_xs[:5],\\n\",\n        \"    heart_train_ys,\\n\",\n        \"    epochs=NUM_EPOCHS,\\n\",\n        \"    batch_size=BATCH_SIZE,\\n\",\n        \"    verbose=False)\\n\",\n        \"print('Test Set Evaluation...')\\n\",\n        \"print(lattice_model.evaluate(heart_test_xs[:5], heart_test_ys))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"bx74CD4Cc4T3\"\n      },\n      \"source\": [\n        \"## Calibrated Lattice Ensemble Model\\n\",\n        \"\\n\",\n        \"When the number of features is large, you can use an ensemble model, which creates multiple smaller lattices for subsets of the features and averages their output instead of creating just a single huge lattice. Ensemble lattice models are constructed using [tfl.configs.CalibratedLatticeEnsembleConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/CalibratedLatticeEnsembleConfig). 
A calibrated lattice ensemble model applies piecewise-linear and categorical calibration on the input feature, followed by an ensemble of lattice models and an optional output piecewise-linear calibration.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"mbg4lsKqnEkV\"\n      },\n      \"source\": [\n        \"### Explicit Lattice Ensemble Initialization\\n\",\n        \"\\n\",\n        \"If you already know which subsets of features you want to feed into your lattices, then you can explicitly set the lattices using feature names. This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"yu8Twg8mdJ18\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# This is a calibrated lattice ensemble model: inputs are calibrated, then\\n\",\n        \"# combined non-linearly and averaged using multiple lattice layers.\\n\",\n        \"explicit_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\\n\",\n        \"    feature_configs=heart_feature_configs,\\n\",\n        \"    lattices=[['trestbps', 'chol', 'ca'], ['fbs', 'restecg', 'thal'],\\n\",\n        \"              ['fbs', 'cp', 'oldpeak'], ['exang', 'slope', 'thalach'],\\n\",\n        \"              ['restecg', 'age', 'sex']],\\n\",\n        \"    num_lattices=5,\\n\",\n        \"    lattice_rank=3,\\n\",\n        \"    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\\n\",\n        \"    output_initialization=[-2.0, 2.0])\\n\",\n        \"# A CalibratedLatticeEnsemble premade model constructed from the given\\n\",\n        \"# model config.\\n\",\n        \"explicit_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\\n\",\n        \"    explicit_ensemble_model_config)\\n\",\n        \"# Let's plot our model.\\n\",\n        \"keras.utils.plot_model(\\n\",\n        \"    explicit_ensemble_model, show_layer_names=False, rankdir='LR')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"PJYR0i6MMDyh\"\n      },\n      \"source\": [\n        \"As before, we compile, fit, and evaluate our model.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"capt98IOMHEm\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"explicit_ensemble_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True)],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE))\\n\",\n        \"explicit_ensemble_model.fit(\\n\",\n        \"    heart_train_xs,\\n\",\n        \"    heart_train_ys,\\n\",\n        \"    epochs=NUM_EPOCHS,\\n\",\n        \"    batch_size=BATCH_SIZE,\\n\",\n        \"    verbose=False)\\n\",\n        \"print('Test Set Evaluation...')\\n\",\n        \"print(explicit_ensemble_model.evaluate(heart_test_xs, heart_test_ys))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"VnI70C9gdKQw\"\n      },\n      \"source\": [\n        \"### Random Lattice Ensemble\\n\",\n        \"\\n\",\n        \"If you are not sure which subsets of features to feed into your lattices, another option is to use random subsets of features for each lattice. 
This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"7EhWrQaPIXj8\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# This is a calibrated lattice ensemble model: inputs are calibrated, then\\n\",\n        \"# combined non-linearly and averaged using multiple lattice layers.\\n\",\n        \"random_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\\n\",\n        \"    feature_configs=heart_feature_configs,\\n\",\n        \"    lattices='random',\\n\",\n        \"    num_lattices=5,\\n\",\n        \"    lattice_rank=3,\\n\",\n        \"    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\\n\",\n        \"    output_initialization=[-2.0, 2.0],\\n\",\n        \"    random_seed=42)\\n\",\n        \"# Now we must set the random lattice structure and construct the model.\\n\",\n        \"tfl.premade_lib.set_random_lattice_ensemble(random_ensemble_model_config)\\n\",\n        \"# A CalibratedLatticeEnsemble premade model constructed from the given\\n\",\n        \"# model config.\\n\",\n        \"random_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\\n\",\n        \"    random_ensemble_model_config)\\n\",\n        \"# Let's plot our model.\\n\",\n        \"keras.utils.plot_model(\\n\",\n        \"    random_ensemble_model, show_layer_names=False, rankdir='LR')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"sbxcIF0PJUDc\"\n      },\n      \"source\": [\n        \"As before, we compile, fit, and evaluate our model.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"w0YdCDyGJY1G\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"random_ensemble_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True)],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE))\\n\",\n        \"random_ensemble_model.fit(\\n\",\n        \"    heart_train_xs,\\n\",\n        \"    heart_train_ys,\\n\",\n        \"    epochs=NUM_EPOCHS,\\n\",\n        \"    batch_size=BATCH_SIZE,\\n\",\n        \"    verbose=False)\\n\",\n        \"print('Test Set Evaluation...')\\n\",\n        \"print(random_ensemble_model.evaluate(heart_test_xs, heart_test_ys))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ZhJWe7fZIs4-\"\n      },\n      \"source\": [\n        \"### RTL Layer Random Lattice Ensemble\\n\",\n        \"\\n\",\n        \"When using a random lattice ensemble, you can specify that the model use a single `tfl.layers.RTL` layer. We note that `tfl.layers.RTL` only supports monotonicity constraints and must have the same lattice size for all features and no per-feature regularization. 
Note that using a `tfl.layers.RTL` layer lets you scale to much larger ensembles than using separate `tfl.layers.Lattice` instances.\\n\",\n        \"\\n\",\n        \"This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"0PC9oRFYJMF_\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Make sure our feature configs have the same lattice size, no per-feature\\n\",\n        \"# regularization, and only monotonicity constraints.\\n\",\n        \"rtl_layer_feature_configs = copy.deepcopy(heart_feature_configs)\\n\",\n        \"for feature_config in rtl_layer_feature_configs:\\n\",\n        \"  feature_config.lattice_size = 2\\n\",\n        \"  feature_config.unimodality = 'none'\\n\",\n        \"  feature_config.reflects_trust_in = None\\n\",\n        \"  feature_config.dominates = None\\n\",\n        \"  feature_config.regularizer_configs = None\\n\",\n        \"# This is a calibrated lattice ensemble model: inputs are calibrated, then\\n\",\n        \"# combined non-linearly and averaged using multiple lattice layers.\\n\",\n        \"rtl_layer_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\\n\",\n        \"    feature_configs=rtl_layer_feature_configs,\\n\",\n        \"    lattices='rtl_layer',\\n\",\n        \"    num_lattices=5,\\n\",\n        \"    lattice_rank=3,\\n\",\n        \"    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\\n\",\n        \"    output_initialization=[-2.0, 2.0],\\n\",\n        \"    random_seed=42)\\n\",\n        \"# A CalibratedLatticeEnsemble premade model constructed from the given\\n\",\n        \"# model config. 
Note that we do not have to specify the lattices by calling\\n\",\n        \"# a helper function (like before with random) because the RTL Layer will take\\n\",\n        \"# care of that for us.\\n\",\n        \"rtl_layer_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\\n\",\n        \"    rtl_layer_ensemble_model_config)\\n\",\n        \"# Let's plot our model.\\n\",\n        \"keras.utils.plot_model(\\n\",\n        \"    rtl_layer_ensemble_model, show_layer_names=False, rankdir='LR')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"yWdxZpS0JWag\"\n      },\n      \"source\": [\n        \"As before, we compile, fit, and evaluate our model.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"HQdkkWwqJW8p\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"rtl_layer_ensemble_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True)],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE))\\n\",\n        \"rtl_layer_ensemble_model.fit(\\n\",\n        \"    heart_train_xs,\\n\",\n        \"    heart_train_ys,\\n\",\n        \"    epochs=NUM_EPOCHS,\\n\",\n        \"    batch_size=BATCH_SIZE,\\n\",\n        \"    verbose=False)\\n\",\n        \"print('Test Set Evaluation...')\\n\",\n        \"print(rtl_layer_ensemble_model.evaluate(heart_test_xs, heart_test_ys))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"A61VpAl8uOiT\"\n      },\n      \"source\": [\n        \"### Crystals Lattice Ensemble\\n\",\n        \"\\n\",\n        \"Premade also provides a heuristic feature arrangement algorithm, called [Crystals](https://papers.nips.cc/paper/6377-fast-and-flexible-monotonic-functions-with-ensembles-of-lattices). To use the Crystals algorithm, first we train a prefitting model that estimates pairwise feature interactions. We then arrange the final ensemble such that features with more non-linear interactions are in the same lattices.\\n\",\n        \"\\n\",\n        \"The Premade Library offers helper functions for constructing the prefitting model configuration and extracting the crystals structure. 
Note that the prefitting model does not need to be fully trained, so a few epochs should be enough.\\n\",\n        \"\\n\",\n        \"This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"yT5eiknCu9sj\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# This is a calibrated lattice ensemble model: inputs are calibrated, then\\n\",\n        \"# combined non-linearly and averaged using multiple lattice layers.\\n\",\n        \"crystals_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\\n\",\n        \"    feature_configs=heart_feature_configs,\\n\",\n        \"    lattices='crystals',\\n\",\n        \"    num_lattices=5,\\n\",\n        \"    lattice_rank=3,\\n\",\n        \"    # We initialize the output to [-2.0, 2.0] since we'll be using logits.\\n\",\n        \"    output_initialization=[-2.0, 2.0],\\n\",\n        \"    random_seed=42)\\n\",\n        \"# Now that we have our model config, we can construct a prefitting model config.\\n\",\n        \"prefitting_model_config = tfl.premade_lib.construct_prefitting_model_config(\\n\",\n        \"    crystals_ensemble_model_config)\\n\",\n        \"# A CalibratedLatticeEnsemble premade model constructed from the given\\n\",\n        \"# prefitting model config.\\n\",\n        \"prefitting_model = tfl.premade.CalibratedLatticeEnsemble(\\n\",\n        \"    prefitting_model_config)\\n\",\n        \"# We can compile and train our prefitting model as we like.\\n\",\n        \"prefitting_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE))\\n\",\n        \"prefitting_model.fit(\\n\",\n        \"    heart_train_xs,\\n\",\n        \"    heart_train_ys,\\n\",\n        \"    epochs=PREFITTING_NUM_EPOCHS,\\n\",\n        \"    batch_size=BATCH_SIZE,\\n\",\n        \"    verbose=False)\\n\",\n        \"# Now that we have our trained prefitting model, we can extract the crystals.\\n\",\n        \"tfl.premade_lib.set_crystals_lattice_ensemble(crystals_ensemble_model_config,\\n\",\n        \"                                              prefitting_model_config,\\n\",\n        \"                                              prefitting_model)\\n\",\n        \"# A CalibratedLatticeEnsemble premade model constructed from the given\\n\",\n        \"# model config.\\n\",\n        \"crystals_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\\n\",\n        \"    crystals_ensemble_model_config)\\n\",\n        \"# Let's plot our model.\\n\",\n        \"keras.utils.plot_model(\\n\",\n        \"    crystals_ensemble_model, show_layer_names=False, rankdir='LR')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"PRLU1z-216h8\"\n      },\n      \"source\": [\n        \"As before, we compile, fit, and evaluate our model.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"U73On3v91-Qq\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"crystals_ensemble_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True)],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE))\\n\",\n        
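\"# Fit the final crystals ensemble on the training data.\\n\",\n        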
\"crystals_ensemble_model.fit(\\n\",\n        \"    heart_train_xs,\\n\",\n        \"    heart_train_ys,\\n\",\n        \"    epochs=NUM_EPOCHS,\\n\",\n        \"    batch_size=BATCH_SIZE,\\n\",\n        \"    verbose=False)\\n\",\n        \"print('Test Set Evaluation...')\\n\",\n        \"print(crystals_ensemble_model.evaluate(heart_test_xs, heart_test_ys))\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"name\": \"premade_models.ipynb\",\n      \"private_outputs\": true,\n      \"provenance\": [],\n      \"toc_visible\": true\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "docs/tutorials/shape_constraints.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"7765UFHoyGx6\"\n      },\n      \"source\": [\n        \"##### Copyright 2020 The TensorFlow Authors.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"cellView\": \"form\",\n        \"id\": \"KsOkK8O69PyT\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@title Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n\",\n        \"# you may not use this file except in compliance with the License.\\n\",\n        \"# You may obtain a copy of the License at\\n\",\n        \"#\\n\",\n        \"# https://www.apache.org/licenses/LICENSE-2.0\\n\",\n        \"#\\n\",\n        \"# Unless required by applicable law or agreed to in writing, software\\n\",\n        \"# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n\",\n        \"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n\",\n        \"# See the License for the specific language governing permissions and\\n\",\n        \"# limitations under the License.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"RKQpW0JqQQmY\"\n      },\n      \"source\": [\n        \"# Shape Constraints with Tensorflow Lattice\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"r61fkA2i9Y3_\"\n      },\n      \"source\": [\n        \"\\u003ctable class=\\\"tfo-notebook-buttons\\\" align=\\\"left\\\"\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://www.tensorflow.org/lattice/tutorials/shape_constraints\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/tf_logo_32px.png\\\" /\\u003eView on TensorFlow.org\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/shape_constraints.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/colab_logo_32px.png\\\" /\\u003eRun in Google Colab\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/shape_constraints.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\\\" /\\u003eView source on GitHub\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca href=\\\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/shape_constraints.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/download_logo_32px.png\\\" /\\u003eDownload notebook\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"\\u003c/table\\u003e\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"2plcL3iTVjsp\"\n      },\n      \"source\": [\n        \"## Overview\\n\",\n        \"\\n\",\n        \"This tutorial is an overview of the constraints and regularizers provided by the TensorFlow Lattice (TFL) library. 
Here we use TFL premade models on synthetic datasets, but note that everything in this tutorial can also be done with models constructed from TFL Keras layers.\\n\",\n        \"\\n\",\n        \"Before proceeding, make sure your runtime has all required packages installed (as imported in the code cells below).\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"x769lI12IZXB\"\n      },\n      \"source\": [\n        \"## Setup\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fbBVAR6UeRN5\"\n      },\n      \"source\": [\n        \"Installing TF Lattice package:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"bpXjJKpSd3j4\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@test {\\\"skip\\\": true}\\n\",\n        \"!pip install -U tensorflow tf-keras tensorflow-lattice pydot graphviz\\n\",\n        \"!pip install -U tensorflow_decision_forests\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"jSVl9SHTeSGX\"\n      },\n      \"source\": [\n        \"Importing required packages:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"cellView\": \"both\",\n        \"id\": \"iY6awAl058TV\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import tensorflow as tf\\n\",\n        \"import tensorflow_lattice as tfl\\n\",\n        \"import tensorflow_decision_forests as tfdf\\n\",\n        \"\\n\",\n        \"from IPython.core.pylabtools import figsize\\n\",\n        \"import functools\\n\",\n        \"import logging\\n\",\n        \"import matplotlib\\n\",\n        \"from matplotlib import pyplot as plt\\n\",\n        \"import numpy as np\\n\",\n        \"import pandas as pd\\n\",\n        \"import sys\\n\",\n        \"import tempfile\\n\",\n        \"logging.disable(sys.maxsize)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"8dsfk2oNlakY\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Use Keras 2.\\n\",\n        \"version_fn = getattr(tf.keras, \\\"version\\\", None)\\n\",\n        \"if version_fn and version_fn().startswith(\\\"3.\\\"):\\n\",\n        \"  import tf_keras as keras\\n\",\n        \"else:\\n\",\n        \"  keras = tf.keras\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"7TmBk_IGgJF0\"\n      },\n      \"source\": [\n        \"Default values used in this guide:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"kQHPyPsPUF92\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"NUM_EPOCHS = 1000\\n\",\n        \"BATCH_SIZE = 64\\n\",\n        \"LEARNING_RATE=0.01\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"FjR7D8Ag3z0d\"\n      },\n      \"source\": [\n        \"## Training Dataset for Ranking Restaurants\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"a1YetzbdFOij\"\n      },\n      \"source\": [\n        \"Imagine a simplified scenario where we want to determine whether or not users will click on a restaurant search result. 
The task is to predict the clickthrough rate (CTR) given input features:\\n\",\n        \"- Average rating (`avg_rating`): a numeric feature with values in the range [1,5].\\n\",\n        \"- Number of reviews (`num_reviews`): a numeric feature with values capped at 200, which we use as a measure of trendiness.\\n\",\n        \"- Dollar rating (`dollar_rating`): a categorical feature with string values in the set {\\\"D\\\", \\\"DD\\\", \\\"DDD\\\", \\\"DDDD\\\"}.\\n\",\n        \"\\n\",\n        \"Here we create a synthetic dataset where the true CTR is given by the formula:\\n\",\n        \"$$\\n\",\n        \"CTR = 1 / (1 + exp\\\\{\\\\mbox{b(dollar_rating)}-\\\\mbox{avg_rating}\\\\times log(\\\\mbox{num_reviews}) /4 \\\\})\\n\",\n        \"$$\\n\",\n        \"where $b(\\\\cdot)$ translates each `dollar_rating` to a baseline value:\\n\",\n        \"$$\\n\",\n        \"\\\\mbox{D}\\\\to 3,\\\\ \\\\mbox{DD}\\\\to 2,\\\\ \\\\mbox{DDD}\\\\to 4,\\\\ \\\\mbox{DDDD}\\\\to 4.5.\\n\",\n        \"$$\\n\",\n        \"\\n\",\n        \"This formula reflects typical user patterns. e.g. given everything else fixed, users prefer restaurants with higher star ratings, and \\\"\\\\\\\\$\\\\\\\\$\\\" restaurants will receive more clicks than \\\"\\\\\\\\$\\\", followed by \\\"\\\\\\\\$\\\\\\\\$\\\\\\\\$\\\" and \\\"\\\\\\\\$\\\\\\\\$\\\\\\\\$\\\\\\\\$\\\".\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"mKovnyv1jATw\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"dollar_ratings_vocab = [\\\"D\\\", \\\"DD\\\", \\\"DDD\\\", \\\"DDDD\\\"]\\n\",\n        \"def click_through_rate(avg_ratings, num_reviews, dollar_ratings):\\n\",\n        \"  dollar_rating_baseline = {\\\"D\\\": 3, \\\"DD\\\": 2, \\\"DDD\\\": 4, \\\"DDDD\\\": 4.5}\\n\",\n        \"  return 1 / (1 + np.exp(\\n\",\n        \"      np.array([dollar_rating_baseline[d] for d in dollar_ratings]) -\\n\",\n        \"      avg_ratings * np.log1p(num_reviews) / 4))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"BPlgRdt6jAbP\"\n      },\n      \"source\": [\n        \"Let's take a look at the contour plots of this CTR function.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"KC5qX_XKmc7g\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def color_bar():\\n\",\n        \"  bar = matplotlib.cm.ScalarMappable(\\n\",\n        \"      norm=matplotlib.colors.Normalize(0, 1, True),\\n\",\n        \"      cmap=\\\"viridis\\\",\\n\",\n        \"  )\\n\",\n        \"  bar.set_array([0, 1])\\n\",\n        \"  return bar\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def plot_fns(fns, res=25):\\n\",\n        \"  \\\"\\\"\\\"Generates contour plots for a list of (name, fn) functions.\\\"\\\"\\\"\\n\",\n        \"  num_reviews, avg_ratings = np.meshgrid(\\n\",\n        \"      np.linspace(0, 200, num=res),\\n\",\n        \"      np.linspace(1, 5, num=res),\\n\",\n        \"  )\\n\",\n        \"  figsize(13, 3.5 * len(fns))\\n\",\n        \"  fig, axes = plt.subplots(\\n\",\n        \"      len(fns), len(dollar_ratings_vocab), sharey=True, layout=\\\"constrained\\\"\\n\",\n        \"  )\\n\",\n        \"  axes = axes.flatten()\\n\",\n        \"  axes_index = 0\\n\",\n        \"  for fn_name, fn in fns:\\n\",\n        \"    for dollar_rating_split in dollar_ratings_vocab:\\n\",\n        \"      
dollar_ratings = np.repeat(dollar_rating_split, res**2)\\n\",\n        \"      values = fn(avg_ratings.flatten(), num_reviews.flatten(), dollar_ratings)\\n\",\n        \"      title = \\\"{}: dollar_rating={}\\\".format(fn_name, dollar_rating_split)\\n\",\n        \"      subplot = axes[axes_index]\\n\",\n        \"      axes_index += 1\\n\",\n        \"      subplot.contourf(\\n\",\n        \"          avg_ratings,\\n\",\n        \"          num_reviews,\\n\",\n        \"          np.reshape(values, (res, res)),\\n\",\n        \"          vmin=0,\\n\",\n        \"          vmax=1,\\n\",\n        \"      )\\n\",\n        \"      subplot.title.set_text(title)\\n\",\n        \"      subplot.set(xlabel=\\\"Average Rating\\\")\\n\",\n        \"      subplot.set(ylabel=\\\"Number of Reviews\\\")\\n\",\n        \"      subplot.set(xlim=(1, 5))\\n\",\n        \"\\n\",\n        \"  if len(fns) \\u003c= 2:\\n\",\n        \"    cax = fig.add_axes([\\n\",\n        \"        axes[-1].get_position().x1 + 0.11,\\n\",\n        \"        axes[-1].get_position().y0,\\n\",\n        \"        0.02,\\n\",\n        \"        0.8,\\n\",\n        \"    ])\\n\",\n        \"    _ = fig.colorbar(color_bar(), cax=cax)\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"plot_fns([(\\\"CTR\\\", click_through_rate)])\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Ol91olp3muNN\"\n      },\n      \"source\": [\n        \"### Preparing Data\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"H8BOshZS9xwn\"\n      },\n      \"source\": [\n        \"We now need to create our synthetic datasets. We start by generating a simulated dataset of restaurants and their features.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"MhqcOPdTT_wj\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def sample_restaurants(n):\\n\",\n        \"  avg_ratings = np.random.uniform(1.0, 5.0, n)\\n\",\n        \"  num_reviews = np.round(np.exp(np.random.uniform(0.0, np.log(200), n)))\\n\",\n        \"  dollar_ratings = np.random.choice(dollar_ratings_vocab, n)\\n\",\n        \"  ctr_labels = click_through_rate(avg_ratings, num_reviews, dollar_ratings)\\n\",\n        \"  return avg_ratings, num_reviews, dollar_ratings, ctr_labels\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"np.random.seed(42)\\n\",\n        \"avg_ratings, num_reviews, dollar_ratings, ctr_labels = sample_restaurants(2000)\\n\",\n        \"\\n\",\n        \"figsize(5, 5)\\n\",\n        \"fig, axs = plt.subplots(1, 1, sharey=False, layout=\\\"constrained\\\")\\n\",\n        \"\\n\",\n        \"for rating, marker in [(\\\"D\\\", \\\"o\\\"), (\\\"DD\\\", \\\"^\\\"), (\\\"DDD\\\", \\\"+\\\"), (\\\"DDDD\\\", \\\"x\\\")]:\\n\",\n        \"  plt.scatter(\\n\",\n        \"      x=avg_ratings[np.where(dollar_ratings == rating)],\\n\",\n        \"      y=num_reviews[np.where(dollar_ratings == rating)],\\n\",\n        \"      c=ctr_labels[np.where(dollar_ratings == rating)],\\n\",\n        \"      vmin=0,\\n\",\n        \"      vmax=1,\\n\",\n        \"      marker=marker,\\n\",\n        \"      label=rating)\\n\",\n        \"plt.xlabel(\\\"Average Rating\\\")\\n\",\n        \"plt.ylabel(\\\"Number of Reviews\\\")\\n\",\n        \"plt.legend()\\n\",\n        \"plt.xlim((1, 5))\\n\",\n        \"plt.title(\\\"Distribution of restaurants\\\")\\n\",\n        \"_ = 
fig.colorbar(color_bar(), cax=fig.add_axes([1.05, 0.1, 0.05, 0.85]))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"tRetsfLv_JSR\"\n      },\n      \"source\": [\n        \"Let's produce the training, validation and testing datasets. When a restaurant is viewed in the search results, we can record user's engagement (click or no click) as a sample point.\\n\",\n        \"\\n\",\n        \"In practice, users often do not go through all search results. This means that users will likely only see restaurants already considered \\\"good\\\" by the current ranking model in use. As a result, \\\"good\\\" restaurants are more frequently impressed and over-represented in the training datasets. When using more features, the training dataset can have large gaps in \\\"bad\\\" parts of the feature space.\\n\",\n        \"\\n\",\n        \"When the model is used for ranking, it is often evaluated on all relevant results with a more uniform distribution that is not well-represented by the training dataset. A flexible and complicated model might fail in this case due to overfitting the over-represented data points and thus lack generalizability. We handle this issue by applying domain knowledge to add *shape constraints* that guide the model to make reasonable predictions when it cannot pick them up from the training dataset.\\n\",\n        \"\\n\",\n        \"In this example, the training dataset mostly consists of user interactions with good and popular restaurants. The testing dataset has a uniform distribution to simulate the evaluation setting discussed above. Note that such testing dataset will not be available in a real problem setting.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"jS6WOtXQ8jwX\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def sample_dataset(n, testing_set):\\n\",\n        \"  (avg_ratings, num_reviews, dollar_ratings, ctr_labels) = sample_restaurants(n)\\n\",\n        \"  if testing_set:\\n\",\n        \"    # Testing has a more uniform distribution over all restaurants.\\n\",\n        \"    num_views = np.random.poisson(lam=3, size=n)\\n\",\n        \"  else:\\n\",\n        \"    # Training/validation datasets have more views on popular restaurants.\\n\",\n        \"    num_views = np.random.poisson(lam=ctr_labels * num_reviews / 50.0, size=n)\\n\",\n        \"\\n\",\n        \"  return pd.DataFrame({\\n\",\n        \"      \\\"avg_rating\\\": np.repeat(avg_ratings, num_views),\\n\",\n        \"      \\\"num_reviews\\\": np.repeat(num_reviews, num_views),\\n\",\n        \"      \\\"dollar_rating\\\": np.repeat(dollar_ratings, num_views),\\n\",\n        \"      \\\"clicked\\\": np.random.binomial(n=1, p=np.repeat(ctr_labels, num_views)),\\n\",\n        \"  })\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"# Generate datasets.\\n\",\n        \"np.random.seed(42)\\n\",\n        \"data_train = sample_dataset(500, testing_set=False)\\n\",\n        \"data_val = sample_dataset(500, testing_set=False)\\n\",\n        \"data_test = sample_dataset(500, testing_set=True)\\n\",\n        \"\\n\",\n        \"ds_train = tfdf.keras.pd_dataframe_to_tf_dataset(\\n\",\n        \"    data_train, label=\\\"clicked\\\", batch_size=BATCH_SIZE\\n\",\n        \")\\n\",\n        \"ds_val = tfdf.keras.pd_dataframe_to_tf_dataset(\\n\",\n        \"    data_val, label=\\\"clicked\\\", batch_size=BATCH_SIZE\\n\",\n        \")\\n\",\n        
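\"# The test split is converted to a batched tf.data.Dataset in the same way.\\n\",\n        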
\"ds_test = tfdf.keras.pd_dataframe_to_tf_dataset(\\n\",\n        \"    data_test, label=\\\"clicked\\\", batch_size=BATCH_SIZE\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"# feature_analysis_data is used to find quantiles of featurse.\\n\",\n        \"feature_analysis_data = data_train.copy()\\n\",\n        \"feature_analysis_data[\\\"dollar_rating\\\"] = feature_analysis_data[\\n\",\n        \"    \\\"dollar_rating\\\"\\n\",\n        \"].map({v: i for i, v in enumerate(dollar_ratings_vocab)})\\n\",\n        \"feature_analysis_data = dict(feature_analysis_data)\\n\",\n        \"\\n\",\n        \"# Plotting dataset densities.\\n\",\n        \"figsize(12, 5)\\n\",\n        \"fig, axs = plt.subplots(1, 2, sharey=False, tight_layout=False)\\n\",\n        \"for ax, data, title in [\\n\",\n        \"    (axs[0], data_train, \\\"training\\\"),\\n\",\n        \"    (axs[1], data_test, \\\"testing\\\"),\\n\",\n        \"]:\\n\",\n        \"  _, _, _, density = ax.hist2d(\\n\",\n        \"      x=data[\\\"avg_rating\\\"],\\n\",\n        \"      y=data[\\\"num_reviews\\\"],\\n\",\n        \"      bins=(np.linspace(1, 5, num=21), np.linspace(0, 200, num=21)),\\n\",\n        \"      cmap=\\\"Blues\\\",\\n\",\n        \"  )\\n\",\n        \"  ax.set(xlim=(1, 5))\\n\",\n        \"  ax.set(ylim=(0, 200))\\n\",\n        \"  ax.set(xlabel=\\\"Average Rating\\\")\\n\",\n        \"  ax.set(ylabel=\\\"Number of Reviews\\\")\\n\",\n        \"  ax.title.set_text(\\\"Density of {} examples\\\".format(title))\\n\",\n        \"  _ = fig.colorbar(density, ax=ax)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"qoTrw3FZqvPK\"\n      },\n      \"source\": [\n        \"## Fitting Gradient Boosted Trees\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ZklNowexE3wB\"\n      },\n      \"source\": [\n        \"We first create a few auxillary functions for plotting and calculating validation and test metrics.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"3BqGqScQzlYf\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def pred_fn(model, from_logits, avg_ratings, num_reviews, dollar_rating):\\n\",\n        \"  preds = model.predict(\\n\",\n        \"      tf.data.Dataset.from_tensor_slices({\\n\",\n        \"          \\\"avg_rating\\\": avg_ratings,\\n\",\n        \"          \\\"num_reviews\\\": num_reviews,\\n\",\n        \"          \\\"dollar_rating\\\": dollar_rating,\\n\",\n        \"      }).batch(1),\\n\",\n        \"      verbose=0,\\n\",\n        \"  )\\n\",\n        \"  if from_logits:\\n\",\n        \"    preds = tf.math.sigmoid(preds)\\n\",\n        \"  return preds\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def analyze_model(models, from_logits=False, print_metrics=True):\\n\",\n        \"  pred_fns = []\\n\",\n        \"  for model, name in models:\\n\",\n        \"    if print_metrics:\\n\",\n        \"      metric = model.evaluate(ds_val, return_dict=True, verbose=0)\\n\",\n        \"      print(\\\"Validation AUC: {}\\\".format(metric[\\\"auc\\\"]))\\n\",\n        \"      metric = model.evaluate(ds_test, return_dict=True, verbose=0)\\n\",\n        \"      print(\\\"Testing AUC: {}\\\".format(metric[\\\"auc\\\"]))\\n\",\n        \"\\n\",\n        \"    pred_fns.append(\\n\",\n        \"        (\\\"{} pCTR\\\".format(name), functools.partial(pred_fn, model, 
from_logits))\\n\",\n        \"    )\\n\",\n        \"\\n\",\n        \"  pred_fns.append((\\\"CTR\\\", click_through_rate))\\n\",\n        \"  plot_fns(pred_fns)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"JVef4f8yUUbs\"\n      },\n      \"source\": [\n        \"We can fit TensorFlow gradient boosted decision trees on the dataset:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"DnPYlRAo2mnQ\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"gbt_model = tfdf.keras.GradientBoostedTreesModel(\\n\",\n        \"    features=[\\n\",\n        \"        tfdf.keras.FeatureUsage(name=\\\"num_reviews\\\"),\\n\",\n        \"        tfdf.keras.FeatureUsage(name=\\\"avg_rating\\\"),\\n\",\n        \"        tfdf.keras.FeatureUsage(name=\\\"dollar_rating\\\"),\\n\",\n        \"    ],\\n\",\n        \"    exclude_non_specified_features=True,\\n\",\n        \"    num_threads=1,\\n\",\n        \"    num_trees=32,\\n\",\n        \"    max_depth=6,\\n\",\n        \"    min_examples=10,\\n\",\n        \"    growing_strategy=\\\"BEST_FIRST_GLOBAL\\\",\\n\",\n        \"    random_seed=42,\\n\",\n        \"    temp_directory=tempfile.mkdtemp(),\\n\",\n        \")\\n\",\n        \"gbt_model.compile(metrics=[keras.metrics.AUC(name=\\\"auc\\\")])\\n\",\n        \"gbt_model.fit(ds_train, validation_data=ds_val, verbose=0)\\n\",\n        \"analyze_model([(gbt_model, \\\"GBT\\\")])\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"nYZtd6YvsNdn\"\n      },\n      \"source\": [\n        \"Even though the model has captured the general shape of the true CTR and has decent validation metrics, it has counter-intuitive behavior in several parts of the input space: the estimated CTR decreases as the average rating or number of reviews increase. This is due to a lack of sample points in areas not well-covered by the training dataset. The model simply has no way to deduce the correct behaviour solely from the data.\\n\",\n        \"\\n\",\n        \"To solve this issue, we enforce the shape constraint that the model must output values monotonically increasing with respect to both the average rating and the number of reviews. We will later see how to implement this in TFL.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Uf7WqGooFiEp\"\n      },\n      \"source\": [\n        \"## Fitting a DNN\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"_s2aT3x0E_tF\"\n      },\n      \"source\": [\n        \"We can repeat the same steps with a DNN classifier. 
We can observe a similar pattern: not having enough sample points with a small number of reviews results in nonsensical extrapolation.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"WKZzCY-UkZX-\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"keras.utils.set_random_seed(42)\\n\",\n        \"inputs = {\\n\",\n        \"    \\\"num_reviews\\\": keras.Input(shape=(1,), dtype=tf.float32),\\n\",\n        \"    \\\"avg_rating\\\": keras.Input(shape=(1), dtype=tf.float32),\\n\",\n        \"    \\\"dollar_rating\\\": keras.Input(shape=(1), dtype=tf.string),\\n\",\n        \"}\\n\",\n        \"inputs_flat = keras.layers.Concatenate()([\\n\",\n        \"    inputs[\\\"num_reviews\\\"],\\n\",\n        \"    inputs[\\\"avg_rating\\\"],\\n\",\n        \"    keras.layers.StringLookup(\\n\",\n        \"        vocabulary=dollar_ratings_vocab,\\n\",\n        \"        num_oov_indices=0,\\n\",\n        \"        output_mode=\\\"one_hot\\\",\\n\",\n        \"    )(inputs[\\\"dollar_rating\\\"]),\\n\",\n        \"])\\n\",\n        \"dense_layers = keras.Sequential(\\n\",\n        \"    [\\n\",\n        \"        keras.layers.Dense(16, activation=\\\"relu\\\"),\\n\",\n        \"        keras.layers.Dense(16, activation=\\\"relu\\\"),\\n\",\n        \"        keras.layers.Dense(1, activation=None),\\n\",\n        \"    ],\\n\",\n        \"    name=\\\"dense_layers\\\",\\n\",\n        \")\\n\",\n        \"dnn_model = keras.Model(inputs=inputs, outputs=dense_layers(inputs_flat))\\n\",\n        \"keras.utils.plot_model(\\n\",\n        \"    dnn_model, expand_nested=True, show_layer_names=False, rankdir=\\\"LR\\\"\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"6zFqu6wf1I30\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"dnn_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True, name=\\\"auc\\\")],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE),\\n\",\n        \")\\n\",\n        \"dnn_model.fit(ds_train, epochs=200, verbose=0)\\n\",\n        \"analyze_model([(dnn_model, \\\"DNN\\\")], from_logits=True)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"0Avkw-okw7JL\"\n      },\n      \"source\": [\n        \"## Shape Constraints\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"3ExyethCFBrP\"\n      },\n      \"source\": [\n        \"TensorFlow Lattice (TFL) is focused on enforcing shape constraints to safeguard model behavior beyond the training data. These shape constraints are applied to TFL Keras layers. 
Their details can be found in [our JMLR paper](http://jmlr.org/papers/volume17/15-243/15-243.pdf).\\n\",\n        \"\\n\",\n        \"In this tutorial we use TF premade models to cover various shape constraints, but note that all these steps can be done with models created from TFL Keras layers.\\n\",\n        \"\\n\",\n        \"Using TFL premade models also requires:\\n\",\n        \"- a *model config*: defining the model architecture and per-feature shape constraints and regularizers.\\n\",\n        \"- a *feature analysis dataset*: a dataset used for TFL initialization (feature quantile calculation).\\n\",\n        \"\\n\",\n        \"For a more thorough description, please refer to the premade models or the API docs.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"anyCM4sCpOSo\"\n      },\n      \"source\": [\n        \"### Monotonicity\\n\",\n        \"We first address the monotonicity concerns by adding monotonicity shape constraints to the continuous features. We use a calibrated lattice model with added output calibration: each feature is calibrated using categorical or piecewise-linear calibrators, then fed into a lattice model, followed by an output piecewise-linear calibrator.\\n\",\n        \"\\n\",\n        \"To instruct TFL to enforce shape constraints, we specify the constraints in the *feature configs*. The following code shows how we can require the output to be monotonically increasing with respect to both `num_reviews` and `avg_rating` by setting `monotonicity=\\\"increasing\\\"`.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"hFlkZs5RgFcP\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"model_config = tfl.configs.CalibratedLatticeConfig(\\n\",\n        \"    feature_configs=[\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"num_reviews\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            monotonicity=\\\"increasing\\\",\\n\",\n        \"            pwl_calibration_num_keypoints=32,\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"avg_rating\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            monotonicity=\\\"increasing\\\",\\n\",\n        \"            pwl_calibration_num_keypoints=32,\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"dollar_rating\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            pwl_calibration_num_keypoints=4,\\n\",\n        \"            vocabulary_list=dollar_ratings_vocab,\\n\",\n        \"            num_buckets=len(dollar_ratings_vocab),\\n\",\n        \"        ),\\n\",\n        \"    ],\\n\",\n        \"    output_calibration=True,\\n\",\n        \"    output_initialization=np.linspace(-2, 2, num=5),\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"GOlzuyQsGre5\"\n      },\n      \"source\": [\n        \"We now use the `feature_analysis_data` to find and set the quantile values for the input features. 
These values can be pre-calculated and set explicitly in the feature config depending on the training pipeline.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"f-bTmfBnghuX\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"feature_analysis_data = data_train.copy()\\n\",\n        \"feature_analysis_data[\\\"dollar_rating\\\"] = feature_analysis_data[\\n\",\n        \"    \\\"dollar_rating\\\"\\n\",\n        \"].map({v: i for i, v in enumerate(dollar_ratings_vocab)})\\n\",\n        \"feature_analysis_data = dict(feature_analysis_data)\\n\",\n        \"\\n\",\n        \"feature_keypoints = tfl.premade_lib.compute_feature_keypoints(\\n\",\n        \"    feature_configs=model_config.feature_configs, features=feature_analysis_data\\n\",\n        \")\\n\",\n        \"tfl.premade_lib.set_feature_keypoints(\\n\",\n        \"    feature_configs=model_config.feature_configs,\\n\",\n        \"    feature_keypoints=feature_keypoints,\\n\",\n        \"    add_missing_feature_configs=False,\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"FCm1lOjmwur_\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"keras.utils.set_random_seed(42)\\n\",\n        \"inputs = {\\n\",\n        \"    \\\"num_reviews\\\": keras.Input(shape=(1,), dtype=tf.float32),\\n\",\n        \"    \\\"avg_rating\\\": keras.Input(shape=(1), dtype=tf.float32),\\n\",\n        \"    \\\"dollar_rating\\\": keras.Input(shape=(1), dtype=tf.string),\\n\",\n        \"}\\n\",\n        \"ordered_inputs = [\\n\",\n        \"    inputs[\\\"num_reviews\\\"],\\n\",\n        \"    inputs[\\\"avg_rating\\\"],\\n\",\n        \"    keras.layers.StringLookup(\\n\",\n        \"        vocabulary=dollar_ratings_vocab,\\n\",\n        \"        num_oov_indices=0,\\n\",\n        \"        output_mode=\\\"int\\\",\\n\",\n        \"    )(inputs[\\\"dollar_rating\\\"]),\\n\",\n        \"]\\n\",\n        \"outputs = tfl.premade.CalibratedLattice(\\n\",\n        \"    model_config=model_config, name=\\\"CalibratedLattice\\\"\\n\",\n        \")(ordered_inputs)\\n\",\n        \"tfl_model_0 = keras.Model(inputs=inputs, outputs=outputs)\\n\",\n        \"\\n\",\n        \"keras.utils.plot_model(\\n\",\n        \"    tfl_model_0, expand_nested=True, show_layer_names=False, rankdir=\\\"LR\\\"\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ubNRBCWW5wQ9\"\n      },\n      \"source\": [\n        \"Using a `CalibratedLatticeConfig` creates a premade classifier that first applies a *calibrator* to each input (a piece-wise linear function for numeric features) followed by a *lattice* layer to non-linearly fuse the calibrated features. 
We have also enabled output piece-wise linear calibration.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Am1OwtzzU7no\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"tfl_model_0.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True, name=\\\"auc\\\")],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE),\\n\",\n        \")\\n\",\n        \"tfl_model_0.fit(ds_train, epochs=100, verbose=0)\\n\",\n        \"analyze_model([(tfl_model_0, \\\"TFL0\\\")], from_logits=True)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"7vZ5fShXs504\"\n      },\n      \"source\": [\n        \"With the constraints added, the estimated CTR will always increase as the average rating increases or the number of reviews increases. This is done by making sure that the calibrators and the lattice are monotonic.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"pSUd6aFlpYz4\"\n      },\n      \"source\": [\n        \"### Partial Monotonicity for Categorical Calibration\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"CnPiqf4rq6kJ\"\n      },\n      \"source\": [\n        \"To use constraints on the third feature, `dollar_rating`, we should recall that categorical features require a slightly different treatment in TFL. Here we enforce the partial monotonicity constraint that outputs for \\\"DD\\\" restaurants should be larger than \\\"D\\\" restaurants when all other inputs are fixed. This is done using the `monotonicity` setting in the feature config. 
We also need to use `tfl.premade_lib.set_categorical_monotonicities` to convert the constraints specified in string values into the numerical format understood by the library.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"FH2ItfsTsE3S\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"keras.utils.set_random_seed(42)\\n\",\n        \"model_config = tfl.configs.CalibratedLatticeConfig(\\n\",\n        \"    feature_configs=[\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"num_reviews\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            monotonicity=\\\"increasing\\\",\\n\",\n        \"            pwl_calibration_convexity=\\\"concave\\\",\\n\",\n        \"            pwl_calibration_num_keypoints=32,\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"avg_rating\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            monotonicity=\\\"increasing\\\",\\n\",\n        \"            pwl_calibration_num_keypoints=32,\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"dollar_rating\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            pwl_calibration_num_keypoints=4,\\n\",\n        \"            vocabulary_list=dollar_ratings_vocab,\\n\",\n        \"            num_buckets=len(dollar_ratings_vocab),\\n\",\n        \"            monotonicity=[(\\\"D\\\", \\\"DD\\\")],\\n\",\n        \"        ),\\n\",\n        \"    ],\\n\",\n        \"    output_calibration=True,\\n\",\n        \"    output_initialization=np.linspace(-2, 2, num=5),\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"tfl.premade_lib.set_feature_keypoints(\\n\",\n        \"    feature_configs=model_config.feature_configs,\\n\",\n        \"    feature_keypoints=feature_keypoints,\\n\",\n        \"    add_missing_feature_configs=False,\\n\",\n        \")\\n\",\n        \"tfl.premade_lib.set_categorical_monotonicities(model_config.feature_configs)\\n\",\n        \"\\n\",\n        \"outputs = tfl.premade.CalibratedLattice(\\n\",\n        \"    model_config=model_config, name=\\\"CalibratedLattice\\\"\\n\",\n        \")(ordered_inputs)\\n\",\n        \"tfl_model_1 = keras.Model(inputs=inputs, outputs=outputs)\\n\",\n        \"tfl_model_1.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True, name=\\\"auc\\\")],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE),\\n\",\n        \")\\n\",\n        \"tfl_model_1.fit(ds_train, epochs=100, verbose=0)\\n\",\n        \"analyze_model([(tfl_model_1, \\\"TFL1\\\")], from_logits=True)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"gdIzhYL79_Pp\"\n      },\n      \"source\": [\n        \"Here we also plot the predicted CTR of this model conditioned on `dollar_rating`. 
Notice that all the constraints we required are fulfilled in each of the slices.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"J6CP2Ovapiu3\"\n      },\n      \"source\": [\n        \"### 2D Shape Constraint: Trust\\n\",\n        \"A 5-star rating for a restaurant with only one or two reviews is likely an unreliable rating (the restaurant might not actually be good), whereas a 4-star rating for a restaurant with hundreds of reviews is much more reliable (the restaurant is likely good in this case). We can see that the number of reviews of a restaurant affects how much trust we place in its average rating.\\n\",\n        \"\\n\",\n        \"We can exercise TFL trust constraints to inform the model that the larger (or smaller) value of one feature indicates more reliance or trust of another feature. This is done by setting `reflects_trust_in` configuration in the feature config.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"OA14j0erm6TJ\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"keras.utils.set_random_seed(42)\\n\",\n        \"model_config = tfl.configs.CalibratedLatticeConfig(\\n\",\n        \"    feature_configs=[\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"num_reviews\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            monotonicity=\\\"increasing\\\",\\n\",\n        \"            pwl_calibration_num_keypoints=32,\\n\",\n        \"            # Larger num_reviews indicating more trust in avg_rating.\\n\",\n        \"            reflects_trust_in=[\\n\",\n        \"                tfl.configs.TrustConfig(\\n\",\n        \"                    feature_name=\\\"avg_rating\\\", trust_type=\\\"edgeworth\\\"\\n\",\n        \"                ),\\n\",\n        \"            ],\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"avg_rating\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            monotonicity=\\\"increasing\\\",\\n\",\n        \"            pwl_calibration_num_keypoints=32,\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"dollar_rating\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            pwl_calibration_num_keypoints=4,\\n\",\n        \"            vocabulary_list=dollar_ratings_vocab,\\n\",\n        \"            num_buckets=len(dollar_ratings_vocab),\\n\",\n        \"            monotonicity=[(\\\"D\\\", \\\"DD\\\")],\\n\",\n        \"        ),\\n\",\n        \"    ],\\n\",\n        \"    output_calibration=True,\\n\",\n        \"    output_initialization=np.linspace(-2, 2, num=5),\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"tfl.premade_lib.set_feature_keypoints(\\n\",\n        \"    feature_configs=model_config.feature_configs,\\n\",\n        \"    feature_keypoints=feature_keypoints,\\n\",\n        \"    add_missing_feature_configs=False,\\n\",\n        \")\\n\",\n        \"tfl.premade_lib.set_categorical_monotonicities(model_config.feature_configs)\\n\",\n        \"\\n\",\n        \"outputs = tfl.premade.CalibratedLattice(\\n\",\n        \"    model_config=model_config, name=\\\"CalibratedLattice\\\"\\n\",\n        \")(ordered_inputs)\\n\",\n        \"tfl_model_2 = keras.Model(inputs=inputs, outputs=outputs)\\n\",\n        
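\"# Compile and train the trust-constrained model as before.\\n\",\n        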
\"tfl_model_2.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True, name=\\\"auc\\\")],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE),\\n\",\n        \")\\n\",\n        \"tfl_model_2.fit(ds_train, epochs=100, verbose=0)\\n\",\n        \"analyze_model([(tfl_model_2, \\\"TFL2\\\")], from_logits=True)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"puvP9X8XxyRV\"\n      },\n      \"source\": [\n        \"The following plot presents the trained lattice function. Due to the trust constraint, we expect that larger values of calibrated `num_reviews` would force higher slope with respect to calibrated `avg_rating`, resulting in a more significant move in the lattice output.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"cellView\": \"both\",\n        \"id\": \"RounEQebxxnA\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"lattice_params = tfl_model_2.layers[-1].layers[-2].weights[0].numpy()\\n\",\n        \"lat_mesh_x, lat_mesh_y = np.meshgrid(\\n\",\n        \"    np.linspace(0, 1, num=3),\\n\",\n        \"    np.linspace(0, 1, num=3),\\n\",\n        \")\\n\",\n        \"lat_mesh_z = np.reshape(np.asarray(lattice_params[0::3]), (3, 3))\\n\",\n        \"\\n\",\n        \"figure = plt.figure(figsize=(6, 6))\\n\",\n        \"axes = figure.add_subplot(projection=\\\"3d\\\")\\n\",\n        \"axes.plot_wireframe(lat_mesh_x, lat_mesh_y, lat_mesh_z, color=\\\"dodgerblue\\\")\\n\",\n        \"plt.legend([\\\"Lattice Lookup\\\"])\\n\",\n        \"plt.title(\\\"Trust\\\")\\n\",\n        \"plt.xlabel(\\\"Calibrated avg_rating\\\")\\n\",\n        \"plt.ylabel(\\\"Calibrated num_reviews\\\")\\n\",\n        \"plt.show()\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"RfniRZCHIvfK\"\n      },\n      \"source\": [\n        \"### Diminishing Returns\\n\",\n        \"[Diminishing returns](https://en.wikipedia.org/wiki/Diminishing_returns) means that the marginal gain of increasing a certain feature value will decrease as we increase the value. In our case we expect that the `num_reviews` feature follows this pattern, so we can configure its calibrator accordingly. 
Notice that we can decompose diminishing returns into two sufficient conditions:\\n\",\n        \"\\n\",\n        \"- the calibrator is monotonically increasing, and\\n\",\n        \"- the calibrator is concave (setting `pwl_calibration_convexity=\\\"concave\\\"`).\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"XQrM9BskY-wx\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"keras.utils.set_random_seed(42)\\n\",\n        \"model_config = tfl.configs.CalibratedLatticeConfig(\\n\",\n        \"    feature_configs=[\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"num_reviews\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            monotonicity=\\\"increasing\\\",\\n\",\n        \"            pwl_calibration_convexity=\\\"concave\\\",\\n\",\n        \"            pwl_calibration_num_keypoints=32,\\n\",\n        \"            reflects_trust_in=[\\n\",\n        \"                tfl.configs.TrustConfig(\\n\",\n        \"                    feature_name=\\\"avg_rating\\\", trust_type=\\\"edgeworth\\\"\\n\",\n        \"                ),\\n\",\n        \"            ],\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"avg_rating\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            monotonicity=\\\"increasing\\\",\\n\",\n        \"            pwl_calibration_num_keypoints=32,\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"dollar_rating\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            pwl_calibration_num_keypoints=4,\\n\",\n        \"            vocabulary_list=dollar_ratings_vocab,\\n\",\n        \"            num_buckets=len(dollar_ratings_vocab),\\n\",\n        \"            monotonicity=[(\\\"D\\\", \\\"DD\\\")],\\n\",\n        \"        ),\\n\",\n        \"    ],\\n\",\n        \"    output_calibration=True,\\n\",\n        \"    output_initialization=np.linspace(-2, 2, num=5),\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"tfl.premade_lib.set_feature_keypoints(\\n\",\n        \"    feature_configs=model_config.feature_configs,\\n\",\n        \"    feature_keypoints=feature_keypoints,\\n\",\n        \"    add_missing_feature_configs=False,\\n\",\n        \")\\n\",\n        \"tfl.premade_lib.set_categorical_monotonicities(model_config.feature_configs)\\n\",\n        \"\\n\",\n        \"outputs = tfl.premade.CalibratedLattice(\\n\",\n        \"    model_config=model_config, name=\\\"CalibratedLattice\\\"\\n\",\n        \")(ordered_inputs)\\n\",\n        \"tfl_model_3 = keras.Model(inputs=inputs, outputs=outputs)\\n\",\n        \"tfl_model_3.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True, name=\\\"auc\\\")],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE),\\n\",\n        \")\\n\",\n        \"tfl_model_3.fit(\\n\",\n        \"    ds_train,\\n\",\n        \"    epochs=100,\\n\",\n        \"    verbose=0\\n\",\n        \")\\n\",\n        \"analyze_model([(tfl_model_3, \\\"TFL3\\\")], from_logits=True)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"LSmzHkPUo9u5\"\n      },\n      \"source\": [\n        \"Notice how the testing metric improves by adding the 
concavity constraint. The prediction plot also better resembles the ground truth.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"SKe3UHX6pUjw\"\n      },\n      \"source\": [\n        \"### Smoothing Calibrators\\n\",\n        \"We notice in the prediction curves above that even though the output is monotonic in specified features, the changes in the slopes are abrupt and hard to interpret. That suggests we might want to consider smoothing this calibrator using a regularizer setup in the `regularizer_configs`.\\n\",\n        \"\\n\",\n        \"Here we apply a `hessian` regularizer to make the calibration more linear. You can also use the `laplacian` regularizer to flatten the calibrator and the `wrinkle` regularizer to reduce changes in the curvature.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"CxcCNxhkqC7u\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"keras.utils.set_random_seed(42)\\n\",\n        \"model_config = tfl.configs.CalibratedLatticeConfig(\\n\",\n        \"    feature_configs=[\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"num_reviews\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            monotonicity=\\\"increasing\\\",\\n\",\n        \"            pwl_calibration_convexity=\\\"concave\\\",\\n\",\n        \"            pwl_calibration_num_keypoints=32,\\n\",\n        \"            regularizer_configs=[\\n\",\n        \"                tfl.configs.RegularizerConfig(name=\\\"calib_hessian\\\", l2=0.5),\\n\",\n        \"            ],\\n\",\n        \"            reflects_trust_in=[\\n\",\n        \"                tfl.configs.TrustConfig(\\n\",\n        \"                    feature_name=\\\"avg_rating\\\", trust_type=\\\"edgeworth\\\"\\n\",\n        \"                ),\\n\",\n        \"            ],\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"avg_rating\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            monotonicity=\\\"increasing\\\",\\n\",\n        \"            pwl_calibration_num_keypoints=32,\\n\",\n        \"            regularizer_configs=[\\n\",\n        \"                tfl.configs.RegularizerConfig(name=\\\"calib_hessian\\\", l2=0.5),\\n\",\n        \"            ],\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name=\\\"dollar_rating\\\",\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            pwl_calibration_num_keypoints=4,\\n\",\n        \"            vocabulary_list=dollar_ratings_vocab,\\n\",\n        \"            num_buckets=len(dollar_ratings_vocab),\\n\",\n        \"            monotonicity=[(\\\"D\\\", \\\"DD\\\")],\\n\",\n        \"        ),\\n\",\n        \"    ],\\n\",\n        \"    output_calibration=True,\\n\",\n        \"    output_initialization=np.linspace(-2, 2, num=5),\\n\",\n        \"    regularizer_configs=[\\n\",\n        \"        tfl.configs.RegularizerConfig(name=\\\"calib_hessian\\\", l2=0.1),\\n\",\n        \"    ],\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"tfl.premade_lib.set_feature_keypoints(\\n\",\n        \"    feature_configs=model_config.feature_configs,\\n\",\n        \"    feature_keypoints=feature_keypoints,\\n\",\n        \"    add_missing_feature_configs=False,\\n\",\n        \")\\n\",\n        
\"tfl.premade_lib.set_categorical_monotonicities(model_config.feature_configs)\\n\",\n        \"\\n\",\n        \"outputs = tfl.premade.CalibratedLattice(\\n\",\n        \"    model_config=model_config, name=\\\"CalibratedLattice\\\"\\n\",\n        \")(ordered_inputs)\\n\",\n        \"tfl_model_4 = keras.Model(inputs=inputs, outputs=outputs)\\n\",\n        \"tfl_model_4.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.AUC(from_logits=True, name=\\\"auc\\\")],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATE),\\n\",\n        \")\\n\",\n        \"tfl_model_4.fit(ds_train, epochs=100, verbose=0)\\n\",\n        \"analyze_model([(tfl_model_4, \\\"TFL4\\\")], from_logits=True)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"HHpp4goLvuPi\"\n      },\n      \"source\": [\n        \"The calibrators are now smooth, and the overall estimated CTR better matches the ground truth. This is reflected both in the testing metric and in the contour plots.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"TLOGDrYY0hH7\"\n      },\n      \"source\": [\n        \"Here you can see the results of each step as we added domain-specific constraints and regularizers to the model.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"nUEuihX815ix\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"analyze_model(\\n\",\n        \"    [\\n\",\n        \"        (tfl_model_0, \\\"TFL0\\\"),\\n\",\n        \"        (tfl_model_1, \\\"TFL1\\\"),\\n\",\n        \"        (tfl_model_2, \\\"TFL2\\\"),\\n\",\n        \"        (tfl_model_3, \\\"TFL3\\\"),\\n\",\n        \"        (tfl_model_4, \\\"TFL4\\\"),\\n\",\n        \"    ],\\n\",\n        \"    from_logits=True,\\n\",\n        \"    print_metrics=False,\\n\",\n        \")\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"collapsed_sections\": [],\n      \"name\": \"shape_constraints.ipynb\",\n      \"private_outputs\": true,\n      \"provenance\": [],\n      \"toc_visible\": true\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "docs/tutorials/shape_constraints_for_ethics.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"R2AxpObRncMd\"\n      },\n      \"source\": [\n        \"***Copyright 2020 The TensorFlow Authors.***\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"cellView\": \"form\",\n        \"id\": \"gQ5Kfh1YnkFS\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@title Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n\",\n        \"# you may not use this file except in compliance with the License.\\n\",\n        \"# You may obtain a copy of the License at\\n\",\n        \"#\\n\",\n        \"# https://www.apache.org/licenses/LICENSE-2.0\\n\",\n        \"#\\n\",\n        \"# Unless required by applicable law or agreed to in writing, software\\n\",\n        \"# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n\",\n        \"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n\",\n        \"# See the License for the specific language governing permissions and\\n\",\n        \"# limitations under the License.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"uc0VwsT5nvQi\"\n      },\n      \"source\": [\n        \"# Shape Constraints for Ethics with Tensorflow Lattice\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"gqJQZdvfn32j\"\n      },\n      \"source\": [\n        \"\\u003ctable class=\\\"tfo-notebook-buttons\\\" align=\\\"left\\\"\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://www.tensorflow.org/lattice/tutorials/shape_constraints_for_ethics\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/tf_logo_32px.png\\\" /\\u003eView on TensorFlow.org\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://colab.research.google.com/github/tensorflow/lattice/blob/master/docs/tutorials/shape_constraints_for_ethics.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/colab_logo_32px.png\\\" /\\u003eRun in Google Colab\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca target=\\\"_blank\\\" href=\\\"https://github.com/tensorflow/lattice/blob/master/docs/tutorials/shape_constraints_for_ethics.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\\\" /\\u003eView source on GitHub\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"  \\u003ctd\\u003e\\n\",\n        \"    \\u003ca href=\\\"https://storage.googleapis.com/tensorflow_docs/lattice/docs/tutorials/shape_constraints_for_ethics.ipynb\\\"\\u003e\\u003cimg src=\\\"https://www.tensorflow.org/images/download_logo_32px.png\\\" /\\u003eDownload notebook\\u003c/a\\u003e\\n\",\n        \"  \\u003c/td\\u003e\\n\",\n        \"\\u003c/table\\u003e\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"YFZbuZMAoBny\"\n      },\n      \"source\": [\n        \"## Overview\\n\",\n        \"\\n\",\n        \"This tutorial demonstrates how the TensorFlow Lattice (TFL) library can be used\\n\",\n        \"to train models that behave *responsibly*, and do not violate certain\\n\",\n        \"assumptions that are *ethical* 
or *fair*. In particular, we will focus on using monotonicity constraints to avoid *unfair penalization* of certain attributes. This tutorial includes demonstrations\\n\",\n        \"of the experiments from the paper\\n\",\n        \"[*Deontological Ethics By Monotonicity Shape Constraints*](https://arxiv.org/abs/2001.11990)\\n\",\n        \"by Serena Wang and Maya Gupta, published at\\n\",\n        \"[AISTATS 2020](https://www.aistats.org/).\\n\",\n        \"\\n\",\n        \"We will use TFL premade models on public datasets, but note that\\n\",\n        \"everything in this tutorial can also be done with models constructed from TFL\\n\",\n        \"Keras layers.\\n\",\n        \"\\n\",\n        \"Before proceeding, make sure your runtime has all required packages installed\\n\",\n        \"(as imported in the code cells below).\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"o4L76T-NpgCS\"\n      },\n      \"source\": [\n        \"## Setup\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"6FvmHcqbpkL7\"\n      },\n      \"source\": [\n        \"Installing TF Lattice package:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"f91yvUt_peYs\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@test {\\\"skip\\\": true}\\n\",\n        \"!pip install -U tensorflow tf-keras tensorflow-lattice seaborn pydot graphviz\\n\",\n        \"!pip install -U tensorflow_decision_forests\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"6TDoQsvSpmfx\"\n      },\n      \"source\": [\n        \"Importing required packages:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"KGt0pm0b1O5X\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import tensorflow as tf\\n\",\n        \"import tensorflow_lattice as tfl\\n\",\n        \"import tensorflow_decision_forests as tfdf\\n\",\n        \"\\n\",\n        \"import logging\\n\",\n        \"import matplotlib.pyplot as plt\\n\",\n        \"import numpy as np\\n\",\n        \"import os\\n\",\n        \"import pandas as pd\\n\",\n        \"import seaborn as sns\\n\",\n        \"from sklearn.model_selection import train_test_split\\n\",\n        \"import sys\\n\",\n        \"import tempfile\\n\",\n        \"logging.disable(sys.maxsize)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"csVitiM20zAY\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Use Keras 2.\\n\",\n        \"version_fn = getattr(tf.keras, \\\"version\\\", None)\\n\",\n        \"if version_fn and version_fn().startswith(\\\"3.\\\"):\\n\",\n        \"  import tf_keras as keras\\n\",\n        \"else:\\n\",\n        \"  keras = tf.keras\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"DFN6GOcBAqzv\"\n      },\n      \"source\": [\n        \"Default values used in this tutorial:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"9uqMM2joAnoW\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Default number of training epochs, batch sizes and learning rate.\\n\",\n        \"NUM_EPOCHS = 
256\\n\",\n        \"BATCH_SIZE = 256\\n\",\n        \"LEARNING_RATES = 0.01\\n\",\n        \"# Directory containing dataset files.\\n\",\n        \"DATA_DIR = 'https://raw.githubusercontent.com/serenalwang/shape_constraints_for_ethics/master'\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"OZJQfJvY3ibC\"\n      },\n      \"source\": [\n        \"# Case study #1: Law school admissions\\n\",\n        \"\\n\",\n        \"In the first part of this tutorial, we will consider a case study using the Law\\n\",\n        \"School Admissions dataset from the Law School Admissions Council (LSAC). We will\\n\",\n        \"train a classifier to predict whether or not a student will pass the bar using\\n\",\n        \"two features: the student's LSAT score and undergraduate GPA.\\n\",\n        \"\\n\",\n        \"Suppose that the classifier’s score was used to guide law school admissions or\\n\",\n        \"scholarships. According to merit-based social norms, we would expect that\\n\",\n        \"students with higher GPA and higher LSAT score should receive a higher score\\n\",\n        \"from the classifier. However, we will observe that it is easy for models to\\n\",\n        \"violate these intuitive norms, and sometimes penalize people for having a higher\\n\",\n        \"GPA or LSAT score.\\n\",\n        \"\\n\",\n        \"To address this *unfair penalization* problem, we can impose monotonicity\\n\",\n        \"constraints so that a model never penalizes higher GPA or higher LSAT score, all\\n\",\n        \"else equal. In this tutorial, we will show how to impose those monotonicity\\n\",\n        \"constraints using TFL.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"vJES8lYT1fHN\"\n      },\n      \"source\": [\n        \"## Load Law School Data\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Cl89ZOsQ14An\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Load data file.\\n\",\n        \"law_file_name = 'lsac.csv'\\n\",\n        \"law_file_path = os.path.join(DATA_DIR, law_file_name)\\n\",\n        \"raw_law_df = pd.read_csv(law_file_path, delimiter=',')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"RCpTYCNjqOsC\"\n      },\n      \"source\": [\n        \"Preprocess dataset:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"jdY5rtLs4xQK\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Define label column name.\\n\",\n        \"LAW_LABEL = 'pass_bar'\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"1t1Hd8gu6Uat\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def preprocess_law_data(input_df):\\n\",\n        \"  # Drop rows with where the label or features of interest are missing.\\n\",\n        \"  output_df = input_df[~input_df[LAW_LABEL].isna() \\u0026 ~input_df['ugpa'].isna() \\u0026\\n\",\n        \"                       (input_df['ugpa'] \\u003e 0) \\u0026 ~input_df['lsat'].isna()]\\n\",\n        \"  return output_df\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"law_df = preprocess_law_data(raw_law_df)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": 
\"YhvSKr9SCrHP\"\n      },\n      \"source\": [\n        \"### Split data into train/validation/test sets\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"gQKkIGD-CvGD\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def split_dataset(input_df, random_state=888):\\n\",\n        \"  \\\"\\\"\\\"Splits an input dataset into train, val, and test sets.\\\"\\\"\\\"\\n\",\n        \"  train_df, test_val_df = train_test_split(\\n\",\n        \"      input_df, test_size=0.3, random_state=random_state\\n\",\n        \"  )\\n\",\n        \"  val_df, test_df = train_test_split(\\n\",\n        \"      test_val_df, test_size=0.66, random_state=random_state\\n\",\n        \"  )\\n\",\n        \"  return train_df, val_df, test_df\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"dataframes = {}\\n\",\n        \"datasets = {}\\n\",\n        \"\\n\",\n        \"(dataframes['law_train'], dataframes['law_val'], dataframes['law_test']) = (\\n\",\n        \"    split_dataset(law_df)\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"for df_name, df in dataframes.items():\\n\",\n        \"  datasets[df_name] = tf.data.Dataset.from_tensor_slices(\\n\",\n        \"      ((df[['ugpa']], df[['lsat']]), df[['pass_bar']])\\n\",\n        \"  ).batch(BATCH_SIZE)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"zObwzY7f3aLy\"\n      },\n      \"source\": [\n        \"### Visualize data distribution\\n\",\n        \"\\n\",\n        \"First we will visualize the distribution of the data. We will plot the GPA and\\n\",\n        \"LSAT scores for all students that passed the bar and also for all students that\\n\",\n        \"did not pass the bar.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"dRAZB5cLORUG\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def plot_dataset_contour(input_df, title):\\n\",\n        \"  plt.rcParams['font.family'] = ['serif']\\n\",\n        \"  g = sns.jointplot(\\n\",\n        \"      x='ugpa',\\n\",\n        \"      y='lsat',\\n\",\n        \"      data=input_df,\\n\",\n        \"      kind='kde',\\n\",\n        \"      xlim=[1.4, 4],\\n\",\n        \"      ylim=[0, 50])\\n\",\n        \"  g.plot_joint(plt.scatter, c='b', s=10, linewidth=1, marker='+')\\n\",\n        \"  g.ax_joint.collections[0].set_alpha(0)\\n\",\n        \"  g.set_axis_labels('Undergraduate GPA', 'LSAT score', fontsize=14)\\n\",\n        \"  g.fig.suptitle(title, fontsize=14)\\n\",\n        \"  # Adust plot so that the title fits.\\n\",\n        \"  plt.subplots_adjust(top=0.9)\\n\",\n        \"  plt.show()\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"feovlsWPQhVG\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"law_df_pos = law_df[law_df[LAW_LABEL] == 1]\\n\",\n        \"plot_dataset_contour(\\n\",\n        \"    law_df_pos, title='Distribution of students that passed the bar')\\n\",\n        \"law_df_neg = law_df[law_df[LAW_LABEL] == 0]\\n\",\n        \"plot_dataset_contour(\\n\",\n        \"    law_df_neg, title='Distribution of students that failed the bar')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"6grrFEMPfPjk\"\n      },\n      \"source\": [\n        \"## Train calibrated lattice 
model to predict bar exam passage\\n\",\n        \"\\n\",\n        \"Next, we will train a *calibrated lattice model* from TFL to predict whether or\\n\",\n        \"not a student will pass the bar. The two input features will be LSAT score and\\n\",\n        \"undergraduate GPA, and the training label will be whether the student passed the\\n\",\n        \"bar.\\n\",\n        \"\\n\",\n        \"We will first train a calibrated lattice model without any constraints. Then, we\\n\",\n        \"will train a calibrated lattice model with monotonicity constraints and observe\\n\",\n        \"the difference in the model output and accuracy.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"HSfAwgiO_6YA\"\n      },\n      \"source\": [\n        \"### Helper functions for visualization of trained model outputs\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"aw28Xc7IS6vR\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def plot_model_contour(model, from_logits=False, num_keypoints=20):\\n\",\n        \"  x = np.linspace(min(law_df['ugpa']), max(law_df['ugpa']), num_keypoints)\\n\",\n        \"  y = np.linspace(min(law_df['lsat']), max(law_df['lsat']), num_keypoints)\\n\",\n        \"\\n\",\n        \"  x_grid, y_grid = np.meshgrid(x, y)\\n\",\n        \"\\n\",\n        \"  positions = np.vstack([x_grid.ravel(), y_grid.ravel()])\\n\",\n        \"  plot_df = pd.DataFrame(positions.T, columns=['ugpa', 'lsat'])\\n\",\n        \"  plot_df[LAW_LABEL] = np.ones(len(plot_df))\\n\",\n        \"  predictions = model.predict((plot_df[['ugpa']], plot_df[['lsat']]))\\n\",\n        \"  if from_logits:\\n\",\n        \"    predictions = tf.math.sigmoid(predictions)\\n\",\n        \"  grid_predictions = np.reshape(predictions, x_grid.shape)\\n\",\n        \"\\n\",\n        \"  plt.rcParams['font.family'] = ['serif']\\n\",\n        \"  plt.contour(\\n\",\n        \"      x_grid,\\n\",\n        \"      y_grid,\\n\",\n        \"      grid_predictions,\\n\",\n        \"      colors=('k',),\\n\",\n        \"      levels=np.linspace(0, 1, 11),\\n\",\n        \"  )\\n\",\n        \"  plt.contourf(\\n\",\n        \"      x_grid,\\n\",\n        \"      y_grid,\\n\",\n        \"      grid_predictions,\\n\",\n        \"      cmap=plt.cm.bone,\\n\",\n        \"      levels=np.linspace(0, 1, 11),\\n\",\n        \"  )\\n\",\n        \"  plt.xticks(fontsize=20)\\n\",\n        \"  plt.yticks(fontsize=20)\\n\",\n        \"\\n\",\n        \"  cbar = plt.colorbar()\\n\",\n        \"  cbar.ax.set_ylabel('Model score', fontsize=20)\\n\",\n        \"  cbar.ax.tick_params(labelsize=20)\\n\",\n        \"\\n\",\n        \"  plt.xlabel('Undergraduate GPA', fontsize=20)\\n\",\n        \"  plt.ylabel('LSAT score', fontsize=20)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fAMSCaRHIn1w\"\n      },\n      \"source\": [\n        \"## Train unconstrained (non-monotonic) calibrated lattice model\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"mK7RWDJ5ugdd\"\n      },\n      \"source\": [\n        \"We create a TFL premade model using a `CalibratedLatticeConfig`. 
This model is a calibrated lattice model with an output calibration.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"J16TOicHQ1sM\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"model_config = tfl.configs.CalibratedLatticeConfig(\\n\",\n        \"    feature_configs=[\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name='ugpa',\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            pwl_calibration_num_keypoints=16,\\n\",\n        \"            monotonicity=0,\\n\",\n        \"            pwl_calibration_always_monotonic=False,\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name='lsat',\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            pwl_calibration_num_keypoints=16,\\n\",\n        \"            monotonicity=0,\\n\",\n        \"            pwl_calibration_always_monotonic=False,\\n\",\n        \"        ),\\n\",\n        \"    ],\\n\",\n        \"    output_calibration=True,\\n\",\n        \"    output_initialization=np.linspace(-2, 2, num=8),\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"jt1Rm6qCuuat\"\n      },\n      \"source\": [\n        \"We calculate and populate feature quantiles in the feature configs using the `premade_lib` API.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"eSELqBdURE0F\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"feature_keypoints = tfl.premade_lib.compute_feature_keypoints(\\n\",\n        \"    feature_configs=model_config.feature_configs,\\n\",\n        \"    features=dataframes['law_train'][['ugpa', 'lsat', 'pass_bar']],\\n\",\n        \")\\n\",\n        \"tfl.premade_lib.set_feature_keypoints(\\n\",\n        \"    feature_configs=model_config.feature_configs,\\n\",\n        \"    feature_keypoints=feature_keypoints,\\n\",\n        \"    add_missing_feature_configs=False,\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"ahV2Sn0Xz1aO\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"nomon_lattice_model = tfl.premade.CalibratedLattice(model_config=model_config)\\n\",\n        \"keras.utils.plot_model(\\n\",\n        \"    nomon_lattice_model, expand_nested=True, show_layer_names=False, rankdir=\\\"LR\\\"\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Oc5f-6zNtyxr\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"nomon_lattice_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[\\n\",\n        \"        keras.metrics.BinaryAccuracy(name='accuracy'),\\n\",\n        \"    ],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATES),\\n\",\n        \")\\n\",\n        \"nomon_lattice_model.fit(datasets['law_train'], epochs=NUM_EPOCHS, verbose=0)\\n\",\n        \"\\n\",\n        \"train_acc = nomon_lattice_model.evaluate(datasets['law_train'])[1]\\n\",\n        \"val_acc = nomon_lattice_model.evaluate(datasets['law_val'])[1]\\n\",\n        \"test_acc = nomon_lattice_model.evaluate(datasets['law_test'])[1]\\n\",\n        
\"print(\\n\",\n        \"    'accuracies for train: %f, val: %f, test: %f'\\n\",\n        \"    % (train_acc, val_acc, test_acc)\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"LuFxP9lDTZup\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"plot_model_contour(nomon_lattice_model, from_logits=True)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"eKVkjHg_LaWb\"\n      },\n      \"source\": [\n        \"## Train monotonic calibrated lattice model\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"W42OXWLVwx3w\"\n      },\n      \"source\": [\n        \"We can get a monotonic model by setting the monotonicity constraints in feature configs.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"XeOKlPRc0BQe\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"model_config.feature_configs[0].monotonicity = 1\\n\",\n        \"model_config.feature_configs[1].monotonicity = 1\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"C_MUEvGNp6g2\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"mon_lattice_model = tfl.premade.CalibratedLattice(model_config=model_config)\\n\",\n        \"\\n\",\n        \"mon_lattice_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[\\n\",\n        \"        keras.metrics.BinaryAccuracy(name='accuracy'),\\n\",\n        \"    ],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATES),\\n\",\n        \")\\n\",\n        \"mon_lattice_model.fit(datasets['law_train'], epochs=NUM_EPOCHS, verbose=0)\\n\",\n        \"\\n\",\n        \"train_acc = mon_lattice_model.evaluate(datasets['law_train'])[1]\\n\",\n        \"val_acc = mon_lattice_model.evaluate(datasets['law_val'])[1]\\n\",\n        \"test_acc = mon_lattice_model.evaluate(datasets['law_test'])[1]\\n\",\n        \"print(\\n\",\n        \"    'accuracies for train: %f, val: %f, test: %f'\\n\",\n        \"    % (train_acc, val_acc, test_acc)\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"ABdhYOUVCXzD\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"plot_model_contour(mon_lattice_model, from_logits=True)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"GWzBEV_p0WE-\"\n      },\n      \"source\": [\n        \"We demonstrated that TFL calibrated lattice models could be trained to be\\n\",\n        \"monotonic in both LSAT score and GPA without too big of a sacrifice in accuracy.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fsI14lrFxRha\"\n      },\n      \"source\": [\n        \"## Train other unconstrained models\\n\",\n        \"\\n\",\n        \"How does the calibrated lattice model compare to other types of models, like\\n\",\n        \"deep neural networks (DNNs) or gradient boosted trees (GBTs)? Do DNNs and GBTs\\n\",\n        \"appear to have reasonably fair outputs? To address this question, we will next\\n\",\n        \"train an unconstrained DNN and GBT. 
In fact, we will observe that the DNN and\\n\",\n        \"GBT both easily violate monotonicity in LSAT score and undergraduate GPA.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"uo1ruWXcvUqb\"\n      },\n      \"source\": [\n        \"### Train an unconstrained Deep Neural Network (DNN) model\\n\",\n        \"\\n\",\n        \"The architecture was previously optimized to achieve high validation accuracy.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"3pplraob0Od-\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"keras.utils.set_random_seed(42)\\n\",\n        \"inputs = [\\n\",\n        \"    keras.Input(shape=(1,), dtype=tf.float32),\\n\",\n        \"    keras.Input(shape=(1), dtype=tf.float32),\\n\",\n        \"]\\n\",\n        \"inputs_flat = keras.layers.Concatenate()(inputs)\\n\",\n        \"dense_layers = keras.Sequential(\\n\",\n        \"    [\\n\",\n        \"        keras.layers.Dense(64, activation='relu'),\\n\",\n        \"        keras.layers.Dense(32, activation='relu'),\\n\",\n        \"        keras.layers.Dense(1, activation=None),\\n\",\n        \"    ],\\n\",\n        \"    name='dense_layers',\\n\",\n        \")\\n\",\n        \"dnn_model = keras.Model(inputs=inputs, outputs=dense_layers(inputs_flat))\\n\",\n        \"dnn_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[keras.metrics.BinaryAccuracy(name='accuracy')],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATES),\\n\",\n        \")\\n\",\n        \"dnn_model.fit(datasets['law_train'], epochs=NUM_EPOCHS, verbose=0)\\n\",\n        \"\\n\",\n        \"train_acc = dnn_model.evaluate(datasets['law_train'])[1]\\n\",\n        \"val_acc = dnn_model.evaluate(datasets['law_val'])[1]\\n\",\n        \"test_acc = dnn_model.evaluate(datasets['law_test'])[1]\\n\",\n        \"print(\\n\",\n        \"    'accuracies for train: %f, val: %f, test: %f'\\n\",\n        \"    % (train_acc, val_acc, test_acc)\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"LwPQqLt-E7R4\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"plot_model_contour(dnn_model, from_logits=True)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"OOAKK0_3vWir\"\n      },\n      \"source\": [\n        \"### Train an unconstrained Gradient Boosted Trees (GBT) model\\n\",\n        \"\\n\",\n        \"The tree structure was previously optimized to achieve high validation accuracy.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"6UrCJHqhgd3o\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"tree_model = tfdf.keras.GradientBoostedTreesModel(\\n\",\n        \"    exclude_non_specified_features=False,\\n\",\n        \"    num_threads=1,\\n\",\n        \"    num_trees=20,\\n\",\n        \"    max_depth=4,\\n\",\n        \"    growing_strategy='BEST_FIRST_GLOBAL',\\n\",\n        \"    random_seed=42,\\n\",\n        \"    temp_directory=tempfile.mkdtemp(),\\n\",\n        \")\\n\",\n        \"tree_model.compile(metrics=[keras.metrics.BinaryAccuracy(name='accuracy')])\\n\",\n        \"tree_model.fit(\\n\",\n        \"    datasets['law_train'], 
validation_data=datasets['law_val'], verbose=0\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"tree_train_acc = tree_model.evaluate(datasets['law_train'], verbose=0)[1]\\n\",\n        \"tree_val_acc = tree_model.evaluate(datasets['law_val'], verbose=0)[1]\\n\",\n        \"tree_test_acc = tree_model.evaluate(datasets['law_test'], verbose=0)[1]\\n\",\n        \"print(\\n\",\n        \"    'accuracies for GBT: train: %f, val: %f, test: %f'\\n\",\n        \"    % (tree_train_acc, tree_val_acc, tree_test_acc)\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"AZFyfQT1E_nR\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"plot_model_contour(tree_model)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"uX2qiMlrY8aO\"\n      },\n      \"source\": [\n        \"# Case study #2: Credit Default\\n\",\n        \"\\n\",\n        \"The second case study that we will consider in this tutorial is predicting an\\n\",\n        \"individual's credit default probability. We will use the Default of Credit Card\\n\",\n        \"Clients dataset from the UCI repository. This data was collected from 30,000\\n\",\n        \"Taiwanese credit card users and contains a binary label of whether or not a user\\n\",\n        \"defaulted on a payment in a time window. Features include marital status,\\n\",\n        \"gender, education, and how long a user is behind on payment of their existing\\n\",\n        \"bills, for each of the months of April-September 2005.\\n\",\n        \"\\n\",\n        \"As we did with the first case study, we again illustrate using monotonicity\\n\",\n        \"constraints to avoid *unfair penalization*: if the model were to be used to\\n\",\n        \"determine a user’s credit score, it could feel unfair to many if they were\\n\",\n        \"penalized for paying their bills sooner, all else equal. 
Thus, we apply a\\n\",\n        \"monotonicity constraint that keeps the model from penalizing early payments.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"tz5yduNuFinA\"\n      },\n      \"source\": [\n        \"## Load Credit Default data\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"KuylMNBCILwy\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Load data file.\\n\",\n        \"credit_file_name = 'credit_default.csv'\\n\",\n        \"credit_file_path = os.path.join(DATA_DIR, credit_file_name)\\n\",\n        \"credit_df = pd.read_csv(credit_file_path, delimiter=',')\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Hv_GQcEHIf9v\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Define label column name.\\n\",\n        \"CREDIT_LABEL = 'default'\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"13oZWY0YIoy3\"\n      },\n      \"source\": [\n        \"### Split data into train/validation/test sets\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"dty5tXJqIscz\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"dfs = {}\\n\",\n        \"datasets = {}\\n\",\n        \"\\n\",\n        \"dfs[\\\"credit_train\\\"], dfs[\\\"credit_val\\\"], dfs[\\\"credit_test\\\"] = split_dataset(\\n\",\n        \"    credit_df\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"for df_name, df in dfs.items():\\n\",\n        \"  datasets[df_name] = tf.data.Dataset.from_tensor_slices(\\n\",\n        \"      ((df[['MARRIAGE']], df[['PAY_0']]), df[['default']])\\n\",\n        \"  ).batch(BATCH_SIZE)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"_kAciWXHKGV7\"\n      },\n      \"source\": [\n        \"### Visualize data distribution\\n\",\n        \"\\n\",\n        \"First we will visualize the distribution of the data. We will plot the mean and\\n\",\n        \"standard error of the observed default rate for people with different marital\\n\",\n        \"statuses and repayment statuses. 
The repayment status represents the number of\\n\",\n        \"months a person is behind on paying back their loan (as of April 2005).\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"8CxacQxnkHWE\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def get_agg_data(df, x_col, y_col, bins=11):\\n\",\n        \"  xbins = pd.cut(df[x_col], bins=bins)\\n\",\n        \"  data = df[[x_col, y_col]].groupby(xbins).agg(['mean', 'sem'])\\n\",\n        \"  return data\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def plot_2d_means_credit(input_df, x_col, y_col, x_label, y_label):\\n\",\n        \"  plt.rcParams['font.family'] = ['serif']\\n\",\n        \"  _, ax = plt.subplots(nrows=1, ncols=1)\\n\",\n        \"  plt.setp(ax.spines.values(), color='black', linewidth=1)\\n\",\n        \"  ax.tick_params(\\n\",\n        \"      direction='in', length=6, width=1, top=False, right=False, labelsize=18)\\n\",\n        \"  df_single = get_agg_data(input_df[input_df['MARRIAGE'] == 1], x_col, y_col)\\n\",\n        \"  df_married = get_agg_data(input_df[input_df['MARRIAGE'] == 2], x_col, y_col)\\n\",\n        \"  ax.errorbar(\\n\",\n        \"      df_single[(x_col, 'mean')],\\n\",\n        \"      df_single[(y_col, 'mean')],\\n\",\n        \"      xerr=df_single[(x_col, 'sem')],\\n\",\n        \"      yerr=df_single[(y_col, 'sem')],\\n\",\n        \"      color='orange',\\n\",\n        \"      marker='s',\\n\",\n        \"      capsize=3,\\n\",\n        \"      capthick=1,\\n\",\n        \"      label='Single',\\n\",\n        \"      markersize=10,\\n\",\n        \"      linestyle='')\\n\",\n        \"  ax.errorbar(\\n\",\n        \"      df_married[(x_col, 'mean')],\\n\",\n        \"      df_married[(y_col, 'mean')],\\n\",\n        \"      xerr=df_married[(x_col, 'sem')],\\n\",\n        \"      yerr=df_married[(y_col, 'sem')],\\n\",\n        \"      color='b',\\n\",\n        \"      marker='^',\\n\",\n        \"      capsize=3,\\n\",\n        \"      capthick=1,\\n\",\n        \"      label='Married',\\n\",\n        \"      markersize=10,\\n\",\n        \"      linestyle='')\\n\",\n        \"  leg = ax.legend(loc='upper left', fontsize=18, frameon=True, numpoints=1)\\n\",\n        \"  ax.set_xlabel(x_label, fontsize=18)\\n\",\n        \"  ax.set_ylabel(y_label, fontsize=18)\\n\",\n        \"  ax.set_ylim(0, 1.1)\\n\",\n        \"  ax.set_xlim(-2, 8.5)\\n\",\n        \"  ax.patch.set_facecolor('white')\\n\",\n        \"  leg.get_frame().set_edgecolor('black')\\n\",\n        \"  leg.get_frame().set_facecolor('white')\\n\",\n        \"  leg.get_frame().set_linewidth(1)\\n\",\n        \"  plt.show()\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"VHXyYbyekKLT\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"plot_2d_means_credit(\\n\",\n        \"    dfs['credit_train'],\\n\",\n        \"    'PAY_0',\\n\",\n        \"    'default',\\n\",\n        \"    'Repayment Status (April)',\\n\",\n        \"    'Observed default rate',\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"4hnZBigB7kzY\"\n      },\n      \"source\": [\n        \"## Train calibrated lattice model to predict credit default rate\\n\",\n        \"\\n\",\n        \"Next, we will train a *calibrated lattice model* from TFL to predict whether or\\n\",\n        \"not a 
person will default on a loan. The two input features will be the person's\\n\",\n        \"marital status and how many months the person is behind on paying back their\\n\",\n        \"loans in April (repayment status). The training label will be whether or not the\\n\",\n        \"person defaulted on a loan.\\n\",\n        \"\\n\",\n        \"We will first train a calibrated lattice model without any constraints. Then, we\\n\",\n        \"will train a calibrated lattice model with monotonicity constraints and observe\\n\",\n        \"the difference in the model output and accuracy.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"iwxnlRrQPdTg\"\n      },\n      \"source\": [\n        \"### Helper functions for visualization of trained model outputs\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"zVGxEfbhPZ5H\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def plot_predictions_credit(\\n\",\n        \"    input_df,\\n\",\n        \"    model,\\n\",\n        \"    x_col,\\n\",\n        \"    x_label='Repayment Status (April)',\\n\",\n        \"    y_label='Predicted default probability',\\n\",\n        \"):\\n\",\n        \"  predictions = model.predict((input_df[['MARRIAGE']], input_df[['PAY_0']]))\\n\",\n        \"  predictions = tf.math.sigmoid(predictions)\\n\",\n        \"  new_df = input_df.copy()\\n\",\n        \"  new_df.loc[:, 'predictions'] = predictions\\n\",\n        \"  plot_2d_means_credit(new_df, x_col, 'predictions', x_label, y_label)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"UMIpywE1P07H\"\n      },\n      \"source\": [\n        \"## Train unconstrained (non-monotonic) calibrated lattice model\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"cxGu3gBOApOm\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"model_config = tfl.configs.CalibratedLatticeConfig(\\n\",\n        \"    feature_configs=[\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name='MARRIAGE',\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            pwl_calibration_num_keypoints=2,\\n\",\n        \"            monotonicity=0,\\n\",\n        \"            pwl_calibration_always_monotonic=False,\\n\",\n        \"        ),\\n\",\n        \"        tfl.configs.FeatureConfig(\\n\",\n        \"            name='PAY_0',\\n\",\n        \"            lattice_size=3,\\n\",\n        \"            pwl_calibration_num_keypoints=16,\\n\",\n        \"            monotonicity=0,\\n\",\n        \"            pwl_calibration_always_monotonic=False,\\n\",\n        \"        ),\\n\",\n        \"    ],\\n\",\n        \"    output_calibration=True,\\n\",\n        \"    output_initialization=np.linspace(-2, 2, num=8),\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"cVZKH36LA8BQ\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"feature_keypoints = tfl.premade_lib.compute_feature_keypoints(\\n\",\n        \"    feature_configs=model_config.feature_configs,\\n\",\n        \"    features=dfs[\\\"credit_train\\\"][['MARRIAGE', 'PAY_0', 'default']],\\n\",\n        \")\\n\",\n        \"tfl.premade_lib.set_feature_keypoints(\\n\",\n        \"    
feature_configs=model_config.feature_configs,\\n\",\n        \"    feature_keypoints=feature_keypoints,\\n\",\n        \"    add_missing_feature_configs=False,\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"2It6hvNRA8Bi\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"nomon_lattice_model = tfl.premade.CalibratedLattice(model_config=model_config)\\n\",\n        \"\\n\",\n        \"nomon_lattice_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[\\n\",\n        \"        keras.metrics.BinaryAccuracy(name='accuracy'),\\n\",\n        \"    ],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATES),\\n\",\n        \")\\n\",\n        \"nomon_lattice_model.fit(datasets['credit_train'], epochs=NUM_EPOCHS, verbose=0)\\n\",\n        \"\\n\",\n        \"train_acc = nomon_lattice_model.evaluate(datasets['credit_train'])[1]\\n\",\n        \"val_acc = nomon_lattice_model.evaluate(datasets['credit_val'])[1]\\n\",\n        \"test_acc = nomon_lattice_model.evaluate(datasets['credit_test'])[1]\\n\",\n        \"print(\\n\",\n        \"    'accuracies for train: %f, val: %f, test: %f'\\n\",\n        \"    % (train_acc, val_acc, test_acc)\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"5zQ_jm75kRX6\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"plot_predictions_credit(dfs['credit_train'], nomon_lattice_model, 'PAY_0')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"0aokp7qLQBIr\"\n      },\n      \"source\": [\n        \"## Train monotonic calibrated lattice model\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"MbB2ixYMC6Za\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"model_config.feature_configs[0].monotonicity = 1\\n\",\n        \"model_config.feature_configs[1].monotonicity = 1\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"wWCG7YrLUZDH\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"mon_lattice_model = tfl.premade.CalibratedLattice(model_config=model_config)\\n\",\n        \"\\n\",\n        \"mon_lattice_model.compile(\\n\",\n        \"    loss=keras.losses.BinaryCrossentropy(from_logits=True),\\n\",\n        \"    metrics=[\\n\",\n        \"        keras.metrics.BinaryAccuracy(name='accuracy'),\\n\",\n        \"    ],\\n\",\n        \"    optimizer=keras.optimizers.Adam(LEARNING_RATES),\\n\",\n        \")\\n\",\n        \"mon_lattice_model.fit(datasets['credit_train'], epochs=NUM_EPOCHS, verbose=0)\\n\",\n        \"\\n\",\n        \"train_acc = mon_lattice_model.evaluate(datasets['credit_train'])[1]\\n\",\n        \"val_acc = mon_lattice_model.evaluate(datasets['credit_val'])[1]\\n\",\n        \"test_acc = mon_lattice_model.evaluate(datasets['credit_test'])[1]\\n\",\n        \"print(\\n\",\n        \"    'accuracies for train: %f, val: %f, test: %f'\\n\",\n        \"    % (train_acc, val_acc, test_acc)\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"JCQ2eMdndFhR\"\n      },\n      \"outputs\": [],\n      
\"source\": [\n        \"plot_predictions_credit(dfs['credit_train'], mon_lattice_model, 'PAY_0')\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"collapsed_sections\": [],\n      \"name\": \"shape_constraints_for_ethics.ipynb\",\n      \"private_outputs\": true,\n      \"provenance\": [],\n      \"toc_visible\": true\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "examples/BUILD",
    "content": "# Copyright 2019 The TensorFlow Lattice Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nload(\"//third_party/bazel_rules/rules_python/python:py_binary.bzl\", \"py_binary\")\n\nlicenses([\"notice\"])\n\npackage(\n    default_visibility = [\n        \"//tensorflow_lattice:__subpackages__\",\n    ],\n)\n\npy_binary(\n    name = \"keras_sequential_uci_heart\",\n    srcs = [\"keras_sequential_uci_heart.py\"],\n    python_version = \"PY3\",\n    deps = [\n        # tensorflow dep,\n        \"//tensorflow_lattice\",\n    ],\n)\n\npy_binary(\n    name = \"keras_functional_uci_heart\",\n    srcs = [\"keras_functional_uci_heart.py\"],\n    python_version = \"PY3\",\n    deps = [\n        # tensorflow dep,\n        \"//tensorflow_lattice\",\n    ],\n)\n"
  },
  {
    "path": "examples/keras_functional_uci_heart.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Example usage of TFL within Keras Functional API.\n\nThis example builds and trains a calibrated lattice model for the UCI heart\ndataset.\n\n\"Calibrated lattice\" is a commonly used architecture for datasets where number\nof input features does not exceed ~15.\n\n\"Calibrated lattice\" assumes every feature being transformed by PWLCalibration\nor CategoricalCalibration layers before nonlineary fusing result of calibration\nwithin a lattice layer.\n\nThe TFL package does not have any layers dedicated to processing of sparse\nfeatures but thanks to plug and play compatibility with any other Keras layers\nwe can take advantage of standard Keras embedding to handle sparse features. UCI\nHeart dataset does not have any sparse features so for this example we replaced\nPWLCalibration layer for feature 'age' with Embedding layer in order to\ndemonstrate such compatibility as well as advantage of monotonicity\nconstraints for semantically meaningful features.\n\nGenerally when you manually combine TFL layers you should keep track of:\n1) Ensuring that inputs to TFL layers are within expected range.\n  - Input range for PWLCalibration layer is defined by smallest and largest of\n    provided keypoints.\n  - Input range for Lattice layer is [0.0, lattice_sizes[d] - 1.0] for any\n    dimension d.\n  TFL layers can constraint their output to be within desired range. Feeding\n  output of other layers into TFL layers you might want to ensure that something\n  like sigmoid is used to constraint their output range.\n2) Properly configure monotonicity. If your calibration layer is monotonic then\n  corresponding dimension of lattice layer should also be monotonic.\n\nThis example uses functional API for Keras model construction. 
For an example of\nsequential models with TFL layers see keras_sequential_uci_heart.py.\n\nIn order to see how better generalization can be achieved with a properly\nconstrained PWLCalibration layer compared to a vanila embedding layer, compare\ntraining and validation losses of this model with one defined in\nkeras_sequential_uci_heart.py\n\nNote that the specifics of layer configurations are for demonstration purposes\nand might not result in optimal performance.\n\nExample usage:\nkeras_functional_uci_heart\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl import app\nfrom absl import flags\n\nimport numpy as np\nimport pandas as pd\n\nimport tensorflow as tf\nimport tensorflow_lattice as tfl\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, 'version', None)\nif version_fn and version_fn().startswith('3.'):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\nFLAGS = flags.FLAGS\nflags.DEFINE_integer('num_epochs', 200, 'Number of training epoch.')\n\n\ndef main(_):\n  # UCI Statlog (Heart) dataset.\n  csv_file = keras.utils.get_file(\n      'heart.csv',\n      'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv',\n  )\n  training_data_df = pd.read_csv(csv_file).sample(\n      frac=1.0, random_state=41).reset_index(drop=True)\n\n  # Feature columns.\n  # 0  age\n  # 1  sex\n  # 2  cp        chest pain type (4 values)\n  # 3  trestbps  resting blood pressure\n  # 4  chol      serum cholestoral in mg/dl\n  # 5  fbs       fasting blood sugar > 120 mg/dl\n  # 6  restecg   resting electrocardiographic results (values 0,1,2)\n  # 7  thalach   maximum heart rate achieved\n  # 8  exang     exercise induced angina\n  # 9  oldpeak   ST depression induced by exercise relative to rest\n  # 10 slope     the slope of the peak exercise ST segment\n  # 11 ca        number of major vessels (0-3) colored by flourosopy\n  # 12 thal      3 = normal; 6 = fixed defect; 7 = reversable defect\n\n  # Example slice of training data:\n  #     age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak\n  # 0   63    1   1       145   233    1        2      150      0      2.3\n  # 1   67    1   4       160   286    0        2      108      1      1.5\n  # 2   67    1   4       120   229    0        2      129      1      2.6\n  # 3   37    1   3       130   250    0        0      187      0      3.5\n  # 4   41    0   2       130   204    0        2      172      0      1.4\n  # 5   56    1   2       120   236    0        0      178      0      0.8\n  # 6   62    0   4       140   268    0        2      160      0      3.6\n  # 7   57    0   4       120   354    0        0      163      1      0.6\n  # 8   63    1   4       130   254    0        2      147      0      1.4\n  # 9   53    1   4       140   203    1        2      155      1      3.1\n\n  model_inputs = []\n  lattice_inputs = []\n  # We are going to have 2-d embedding as one of lattice inputs.\n  lattice_sizes_for_embedding = [2, 3]\n  lattice_sizes = lattice_sizes_for_embedding + [2, 2, 3, 3, 2, 2]\n\n  # ############### age ###############\n\n  age_input = keras.layers.Input(shape=[1])\n  model_inputs.append(age_input)\n  age_embedding = keras.layers.Embedding(\n      input_dim=10,\n      output_dim=len(lattice_sizes_for_embedding),\n      embeddings_initializer=keras.initializers.RandomNormal(seed=1))(\n          age_input)\n  # Flatten to get rid of redundant tensor dimension created by embedding 
layer.\n  age_embedding = keras.layers.Flatten()(age_embedding)\n\n  # Lattice expects input data for lattice dimension d to be within\n  # [0, lattice_sizes[d]-1.0]. Apply sigmoid and multiply it by input range to\n  # ensure that lattice inputs are within expected range.\n  embedding_lattice_input_range = tf.constant(\n      [size - 1.0 for size in lattice_sizes_for_embedding],\n      # Insert dimension of size 1 in front to ensure that batch dimension\n      # will not collapse as result of multiplication.\n      shape=(1, 2))\n  age_ranged = keras.layers.multiply(\n      [keras.activations.sigmoid(age_embedding), embedding_lattice_input_range])\n  lattice_inputs.append(age_ranged)\n\n  # ############### sex ###############\n\n  # For boolean features simply specify CategoricalCalibration layer with 2\n  # buckets.\n  sex_input = keras.layers.Input(shape=[1])\n  model_inputs.append(sex_input)\n  sex_calibrator = tfl.layers.CategoricalCalibration(\n      num_buckets=2,\n      output_min=0.0,\n      output_max=lattice_sizes[2] - 1.0,\n      # Initializes all outputs to (output_min + output_max) / 2.0.\n      kernel_initializer='constant',\n  )(\n      sex_input)\n  lattice_inputs.append(sex_calibrator)\n\n  # ############### cp ###############\n\n  cp_input = keras.layers.Input(shape=[1])\n  model_inputs.append(cp_input)\n  cp_calibrator = tfl.layers.PWLCalibration(\n      # Here instead of specifying dtype of layer we convert keypoints into\n      # np.float32.\n      input_keypoints=np.linspace(1, 4, num=4, dtype=np.float32),\n      output_min=0.0,\n      output_max=lattice_sizes[3] - 1.0,\n      monotonicity='increasing',\n      # You can specify TFL regularizers as tuple ('regularizer name', l1, l2).\n      kernel_regularizer=('hessian', 0.0, 1e-4))(\n          cp_input)\n  lattice_inputs.append(cp_calibrator)\n\n  # ############### trestbps ###############\n\n  trestbps_input = keras.layers.Input(shape=[1])\n  model_inputs.append(trestbps_input)\n  trestbps_calibrator = tfl.layers.PWLCalibration(\n      # Alternatively to uniform keypoints you might want to use quantiles as\n      # keypoints.\n      input_keypoints=np.quantile(training_data_df['trestbps'],\n                                  np.linspace(0.0, 1.0, num=5)),\n      dtype=tf.float32,\n      # Together with quantile keypoints you might want to initialize piecewise\n      # linear function to have 'equal_slopes' in order for output of layer\n      # after initialization to preserve original distribution.\n      kernel_initializer='equal_slopes',\n      output_min=0.0,\n      output_max=lattice_sizes[4] - 1.0,\n      # You might consider clamping extreme inputs of the calibrator to output\n      # bounds.\n      clamp_min=True,\n      clamp_max=True,\n      monotonicity='increasing',\n  )(\n      trestbps_input)\n  lattice_inputs.append(trestbps_calibrator)\n\n  # ############### chol ###############\n\n  chol_input = keras.layers.Input(shape=[1])\n  model_inputs.append(chol_input)\n  chol_calibrator = tfl.layers.PWLCalibration(\n      # Explicit input keypoint initialization.\n      input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\n      output_min=0.0,\n      output_max=lattice_sizes[5] - 1.0,\n      # Monotonicity of calibrator can be decreasing. Note that corresponding\n      # lattice dimension must have INCREASING monotonicity regardless of\n      # monotonicity direction of calibrator.\n      # Its not some weird configuration hack. 
Its just how math works :)\n      monotonicity='decreasing',\n      # Convexity together with decreasing monotonicity result in diminishing\n      # return constraint.\n      convexity='convex',\n      # You can specify list of regularizers. You are not limited to TFL\n      # regularizrs. Feel free to use any :)\n      kernel_regularizer=[('laplacian', 0.0, 1e-4),\n                          keras.regularizers.l1_l2(l1=0.001)])(\n                              chol_input)\n  lattice_inputs.append(chol_calibrator)\n\n  # ############### fbs ###############\n\n  fbs_input = keras.layers.Input(shape=[1])\n  model_inputs.append(fbs_input)\n  fbs_calibrator = tfl.layers.CategoricalCalibration(\n      num_buckets=2,\n      output_min=0.0,\n      output_max=lattice_sizes[6] - 1.0,\n      # For categorical calibration layer monotonicity is specified for pairs\n      # of indices of categories. Output for first category in pair will be\n      # smaller than output for second category.\n      #\n      # Don't forget to set monotonicity of corresponding dimension of Lattice\n      # layer to 'increasing'.\n      monotonicities=[(0, 1)],\n      # This initializer is identical to default one ('uniform'), but has fixed\n      # seed in order to simplify experimentation.\n      kernel_initializer=keras.initializers.RandomUniform(\n          minval=0.0, maxval=lattice_sizes[5] - 1.0, seed=1),\n  )(\n      fbs_input)\n  lattice_inputs.append(fbs_calibrator)\n\n  # ############### restecg ###############\n\n  restecg_input = keras.layers.Input(shape=[1])\n  model_inputs.append(restecg_input)\n  restecg_calibrator = tfl.layers.CategoricalCalibration(\n      num_buckets=3,\n      output_min=0.0,\n      output_max=lattice_sizes[7] - 1.0,\n      # Categorical monotonicity can be partial order.\n      monotonicities=[(0, 1), (0, 2)],\n      # Categorical calibration layer supports standard Keras regularizers.\n      kernel_regularizer=keras.regularizers.l1_l2(l1=0.001),\n      kernel_initializer='constant',\n  )(\n      restecg_input)\n  lattice_inputs.append(restecg_calibrator)\n\n  # Lattice inputs must be either list of d tensors of rank (batch_size, 1) or\n  # single tensor of rank (batch_size, d) where d is dimensionality of lattice.\n  # Since our embedding layer has size 2 in second dimension - concatenate all\n  # of inputs to create single tensor.\n  lattice_inputs_tensor = keras.layers.concatenate(lattice_inputs, axis=1)\n\n  # Create Lattice layer to nonlineary fuse output of calibrators. 
Don't forget\n  # to specify 'increasing' monotonicity for any dimension for which\n  # monotonicity is configured regardless of monotonicity direction of those.\n  # This includes partial monotonicity of CategoricalCalibration layer.\n  # Note that making embedding inputs monotonic does not make sense.\n  lattice = tfl.layers.Lattice(\n      lattice_sizes=lattice_sizes,\n      monotonicities=[\n          'none', 'none', 'none', 'increasing', 'increasing', 'increasing',\n          'increasing', 'increasing'\n      ],\n      output_min=0.0,\n      output_max=1.0,\n  )(\n      lattice_inputs_tensor)\n\n  model = keras.models.Model(inputs=model_inputs, outputs=lattice)\n  model.compile(\n      loss=keras.losses.mean_squared_error,\n      optimizer=keras.optimizers.Adagrad(learning_rate=1.0))\n\n  feature_names = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg']\n  features = np.split(\n      training_data_df[feature_names].values.astype(np.float32),\n      indices_or_sections=len(feature_names),\n      axis=1)\n  target = training_data_df[['target']].values.astype(np.float32)\n\n  # Bucketize input for embedding.\n  embedding_bins = np.quantile(\n      features[0],\n      # 10 keypoints will produce 9 dims numbered 1..9 to match embedding input\n      # size of 10.\n      np.linspace(0.0, 1.0, num=10, dtype=np.float32))\n  # Ensure that highest age will get into last bin rather than its own one.\n  embedding_bins[-1] += 1.0\n  features[0] = np.digitize(features[0], bins=embedding_bins)\n\n  model.fit(\n      features,\n      target,\n      batch_size=32,\n      epochs=FLAGS.num_epochs,\n      validation_split=0.2,\n      shuffle=False)\n\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
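As the comments in keras_functional_uci_heart.py note, any unconstrained layer can feed a Lattice as long as its output is squashed into [0, lattice_sizes[d] - 1.0]. A minimal sketch of that rescaling pattern in isolation (the Dense stand-in, sizes, and names below are illustrative assumptions, not taken from the example):

```python
import tensorflow as tf
import tensorflow_lattice as tfl

# Two unconstrained dimensions feeding a 2 x 3 lattice.
lattice_sizes = [2, 3]

inputs = tf.keras.layers.Input(shape=[1])
# Any unconstrained transformation; a small Dense layer stands in for the
# embedding used in the example.
unconstrained = tf.keras.layers.Dense(len(lattice_sizes))(inputs)

# Sigmoid maps outputs into (0, 1); multiplying by (size - 1) maps them into
# the [0, lattice_sizes[d] - 1.0] range the Lattice layer expects.
scale = tf.constant([size - 1.0 for size in lattice_sizes], shape=(1, 2))
lattice_ready = tf.keras.layers.multiply(
    [tf.keras.activations.sigmoid(unconstrained), scale])

output = tfl.layers.Lattice(lattice_sizes=lattice_sizes)(lattice_ready)
model = tf.keras.models.Model(inputs=inputs, outputs=output)
```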
  {
    "path": "examples/keras_sequential_uci_heart.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Example usage of TFL within Keras models.\n\nThis example builds and trains a calibrated lattice model for the UCI heart\ndataset.\n\n\"Calibrated lattice\" is a commonly used architecture for datasets where number\nof input features does not exceed ~15.\n\n\"Calibrated lattice\" assumes every feature being transformed by PWLCalibration\nor CategoricalCalibration layers before nonlineary fusing result of calibration\nwithin a lattice layer.\n\nGenerally when you manually combine TFL layers you should keep track of:\n1) Ensuring that inputs to TFL layers are within expected range.\n  - Input range for PWLCalibration layer is defined by smallest and largest of\n    provided keypoints.\n  - Input range for Lattice layer is [0.0, lattice_sizes[d] - 1.0] for any\n    dimension d.\n  TFL layers can constraint their output to be within desired range. Feeding\n  output of other layers into TFL layers you might want to ensure that something\n  like sigmoid is used to constraint their output range.\n2) Properly configure monotonicity. If your calibration layer is monotonic then\n  corresponding dimension of lattice layer should also be monotonic.\n\nThis example creates a Sequential Keras model and only uses TFL layers. 
For an\nexample of functional model construction that also use embedding layers see\nkeras_functional_uci_heart.py.\n\nIn order to see how better generalization can be achieved with a properly\nconstrained PWLCalibration layer compared to a vanila embedding layer, compare\ntraining and validation losses of this model with one defined in\nkeras_functional_uci_heart.py\n\n\nNote that the specifics of layer configurations are for demonstration purposes\nand might not result in optimal performance.\n\nExample usage:\nkeras_sequential_uci_heart\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl import app\nfrom absl import flags\n\nimport numpy as np\nimport pandas as pd\n\nimport tensorflow as tf\nimport tensorflow_lattice as tfl\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, 'version', None)\nif version_fn and version_fn().startswith('3.'):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\nFLAGS = flags.FLAGS\nflags.DEFINE_integer('num_epochs', 200, 'Number of training epoch.')\n\n\ndef main(_):\n  # UCI Statlog (Heart) dataset.\n  csv_file = keras.utils.get_file(\n      'heart.csv',\n      'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv',\n  )\n  training_data_df = pd.read_csv(csv_file).sample(\n      frac=1.0, random_state=41).reset_index(drop=True)\n\n  # Feature columns.\n  # 0  age\n  # 1  sex\n  # 2  cp        chest pain type (4 values)\n  # 3  trestbps  resting blood pressure\n  # 4  chol      serum cholestoral in mg/dl\n  # 5  fbs       fasting blood sugar > 120 mg/dl\n  # 6  restecg   resting electrocardiographic results (values 0,1,2)\n  # 7  thalach   maximum heart rate achieved\n  # 8  exang     exercise induced angina\n  # 9  oldpeak   ST depression induced by exercise relative to rest\n  # 10 slope     the slope of the peak exercise ST segment\n  # 11 ca        number of major vessels (0-3) colored by flourosopy\n  # 12 thal      3 = normal; 6 = fixed defect; 7 = reversable defect\n\n  # Example slice of training data:\n  #     age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak\n  # 0   63    1   1       145   233    1        2      150      0      2.3\n  # 1   67    1   4       160   286    0        2      108      1      1.5\n  # 2   67    1   4       120   229    0        2      129      1      2.6\n  # 3   37    1   3       130   250    0        0      187      0      3.5\n  # 4   41    0   2       130   204    0        2      172      0      1.4\n  # 5   56    1   2       120   236    0        0      178      0      0.8\n  # 6   62    0   4       140   268    0        2      160      0      3.6\n  # 7   57    0   4       120   354    0        0      163      1      0.6\n  # 8   63    1   4       130   254    0        2      147      0      1.4\n  # 9   53    1   4       140   203    1        2      155      1      3.1\n\n  # Lattice sizes per dimension for Lattice layer.\n  # Lattice layer expects input[i] to be within [0, lattice_sizes[i] - 1.0], so\n  # we need to define lattice sizes ahead of calibration layers so we can\n  # properly specify output range of calibration layers.\n  lattice_sizes = [3, 2, 2, 2, 2, 2, 2]\n\n  # Use ParallelCombination helper layer to group togehter calibration layers\n  # which have to be executed in parallel in order to be able to use Sequential\n  # model. 
Alternatively you can use functional API.\n  combined_calibrators = tfl.layers.ParallelCombination()\n\n  # Configure calibration layers for every feature:\n\n  # ############### age ###############\n\n  calibrator = tfl.layers.PWLCalibration(\n      # Every PWLCalibration layer must have keypoints of piecewise linear\n      # function specified. Easiest way to specify them is to uniformly cover\n      # entire input range by using numpy.linspace().\n      input_keypoints=np.linspace(training_data_df['age'].min(),\n                                  training_data_df['age'].max(),\n                                  num=5),\n      # You need to ensure that input keypoints have same dtype as layer input.\n      # You can do it by setting dtype here or by providing keypoints in such\n      # format which will be converted to desired tf.dtype by default.\n      dtype=tf.float32,\n      # Output range must correspond to expected lattice input range.\n      output_min=0.0,\n      output_max=lattice_sizes[0] - 1.0,\n      monotonicity='increasing')\n  combined_calibrators.append(calibrator)\n\n  # ############### sex ###############\n\n  # For boolean features simply specify CategoricalCalibration layer with 2\n  # buckets.\n  calibrator = tfl.layers.CategoricalCalibration(\n      num_buckets=2,\n      output_min=0.0,\n      output_max=lattice_sizes[1] - 1.0,\n      # Initializes all outputs to (output_min + output_max) / 2.0.\n      kernel_initializer='constant')\n  combined_calibrators.append(calibrator)\n\n  # ############### cp ###############\n\n  calibrator = tfl.layers.PWLCalibration(\n      # Here instead of specifying dtype of layer we convert keypoints into\n      # np.float32.\n      input_keypoints=np.linspace(1, 4, num=4, dtype=np.float32),\n      output_min=0.0,\n      output_max=lattice_sizes[2] - 1.0,\n      monotonicity='increasing',\n      # You can specify TFL regularizers as tuple ('regularizer name', l1, l2).\n      kernel_regularizer=('hessian', 0.0, 1e-4))\n  combined_calibrators.append(calibrator)\n\n  # ############### trestbps ###############\n\n  calibrator = tfl.layers.PWLCalibration(\n      # Alternatively to uniform keypoints you might want to use quantiles as\n      # keypoints.\n      input_keypoints=np.quantile(\n          training_data_df['trestbps'], np.linspace(0.0, 1.0, num=5)),\n      dtype=tf.float32,\n      # Together with quantile keypoints you might want to initialize piecewise\n      # linear function to have 'equal_slopes' in order for output of layer\n      # after initialization to preserve original distribution.\n      kernel_initializer='equal_slopes',\n      output_min=0.0,\n      output_max=lattice_sizes[3] - 1.0,\n      # You might consider clamping extreme inputs of the calibrator to output\n      # bounds.\n      clamp_min=True,\n      clamp_max=True,\n      monotonicity='increasing')\n  combined_calibrators.append(calibrator)\n\n  # ############### chol ###############\n\n  calibrator = tfl.layers.PWLCalibration(\n      # Explicit input keypoint initialization.\n      input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\n      dtype=tf.float32,\n      output_min=0.0,\n      output_max=lattice_sizes[4] - 1.0,\n      # Monotonicity of calibrator can be 'decreasing'. Note that corresponding\n      # lattice dimension must have 'increasing' monotonicity regardless of\n      # monotonicity direction of calibrator.\n      # It's not some weird configuration hack. 
It's just how math works :)\n      monotonicity='decreasing',\n      # Convexity together with decreasing monotonicity result in diminishing\n      # return constraint.\n      convexity='convex',\n      # You can specify list of regularizers. You are not limited to TFL\n      # regularizrs. Feel free to use any :)\n      kernel_regularizer=[('laplacian', 0.0, 1e-4),\n                          keras.regularizers.l1_l2(l1=0.001)])\n  combined_calibrators.append(calibrator)\n\n  # ############### fbs ###############\n\n  calibrator = tfl.layers.CategoricalCalibration(\n      num_buckets=2,\n      output_min=0.0,\n      output_max=lattice_sizes[5] - 1.0,\n      # For categorical calibration layer monotonicity is specified for pairs\n      # of indices of categories. Output for first category in pair will be\n      # smaller than output for second category.\n      #\n      # Don't forget to set monotonicity of corresponding dimension of Lattice\n      # layer to 'increasing'.\n      monotonicities=[(0, 1)],\n      # This initializer is identical to default one('uniform'), but has fixed\n      # seed in order to simplify experimentation.\n      kernel_initializer=keras.initializers.RandomUniform(\n          minval=0.0, maxval=lattice_sizes[5] - 1.0, seed=1))\n  combined_calibrators.append(calibrator)\n\n  # ############### restecg ###############\n\n  calibrator = tfl.layers.CategoricalCalibration(\n      num_buckets=3,\n      output_min=0.0,\n      output_max=lattice_sizes[6] - 1.0,\n      # Categorical monotonicity can be partial order.\n      monotonicities=[(0, 1), (0, 2)],\n      # Categorical calibration layer supports standard Keras regularizers.\n      kernel_regularizer=keras.regularizers.l1_l2(l1=0.001),\n      kernel_initializer='constant')\n  combined_calibrators.append(calibrator)\n\n  # Create Lattice layer to nonlineary fuse output of calibrators. Don't forget\n  # to specify monotonicity 'increasing' for any dimension which calibrator is\n  # monotonic regardless of monotonicity direction of calibrator. This includes\n  # partial monotonicity of CategoricalCalibration layer.\n  lattice = tfl.layers.Lattice(\n      lattice_sizes=lattice_sizes,\n      monotonicities=['increasing', 'none', 'increasing', 'increasing',\n                      'increasing', 'increasing', 'increasing'],\n      output_min=0.0,\n      output_max=1.0)\n\n  model = keras.models.Sequential()\n  # We have just 2 layer as far as Sequential model is concerned.\n  # PWLConcatenate layer takes care of grouping calibrators.\n  model.add(combined_calibrators)\n  model.add(lattice)\n  model.compile(loss=keras.losses.mean_squared_error,\n                optimizer=keras.optimizers.Adagrad(learning_rate=1.0))\n\n  features = training_data_df[\n      ['age', 'sex', 'cp',\n       'trestbps', 'chol', 'fbs', 'restecg']].values.astype(np.float32)\n  target = training_data_df[['target']].values.astype(np.float32)\n\n  model.fit(features,\n            target,\n            batch_size=32,\n            epochs=FLAGS.num_epochs,\n            validation_split=0.2,\n            shuffle=False)\n\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
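The sequential example above boils down to a short pattern: one calibrator per feature appended to a ParallelCombination, stacked with a Lattice. A condensed sketch with two hypothetical features (keypoints, bucket counts, and lattice sizes are invented for illustration):

```python
import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl

lattice_sizes = [2, 3]
combined_calibrators = tfl.layers.ParallelCombination()

# Feature 0: numeric, calibrated by a monotonic piecewise-linear function.
combined_calibrators.append(tfl.layers.PWLCalibration(
    input_keypoints=np.linspace(0.0, 1.0, num=4, dtype=np.float32),
    output_min=0.0,
    output_max=lattice_sizes[0] - 1.0,
    monotonicity='increasing'))

# Feature 1: categorical with 3 buckets.
combined_calibrators.append(tfl.layers.CategoricalCalibration(
    num_buckets=3,
    output_min=0.0,
    output_max=lattice_sizes[1] - 1.0))

# Two layers as far as the Sequential model is concerned: the grouped
# calibrators, then the lattice that fuses their outputs.
model = tf.keras.models.Sequential()
model.add(combined_calibrators)
model.add(tfl.layers.Lattice(
    lattice_sizes=lattice_sizes,
    monotonicities=['increasing', 'none'],
    output_min=0.0,
    output_max=1.0))
model.compile(
    loss=tf.keras.losses.mean_squared_error,
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=1.0))
```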
  {
    "path": "setup.py",
    "content": "# Copyright 2018 The TensorFlow Lattice Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\"); you may not\n# use this file except in compliance with the License. You may obtain a copy of\n# the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT\n# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\n# License for the specific language governing permissions and limitations under\n# the License.\n# ==============================================================================\n\"\"\"Package setup script for TensorFlow Lattice library.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport datetime\nimport sys\n\nfrom setuptools import find_packages\nfrom setuptools import setup\n\n# This version number should always be that of the *next* (unreleased) version.\n# Immediately after uploading a package to PyPI, you should increment the\n# version number and push to gitHub.\n__version__ = \"2.1.1\"\n\nif \"--release\" in sys.argv:\n  sys.argv.remove(\"--release\")\n  _name = \"tensorflow_lattice\"\nelse:\n  # Build a nightly package by default.\n  _name = \"tensorflow_lattice_nightly\"\n  __version__ += datetime.datetime.now().strftime(\".dev%Y%m%d\")\n\n_install_requires = [\n    \"absl-py\",\n    \"numpy\",\n    \"pandas\",\n    \"six\",\n    \"scikit-learn\",\n    \"matplotlib\",\n    \"graphviz\",\n    \"tf-keras\",\n]\n\n# Part of the visualization code uses colabtools and IPython libraries. These\n# are not added as hard requirements as they are mainly used in jupyter/colabs.\n\n_extras_require = {\n    \"tensorflow\": \"tensorflow>=1.15\",\n}\n\n_classifiers = [\n    \"Development Status :: 4 - Beta\",\n    \"Intended Audience :: Developers\",\n    \"Intended Audience :: Education\",\n    \"Intended Audience :: Science/Research\",\n    \"License :: OSI Approved :: Apache Software License\",\n    \"Operating System :: OS Independent\",\n    \"Programming Language :: Python\",\n    \"Programming Language :: Python :: 2\",\n    \"Programming Language :: Python :: 3\",\n    \"Topic :: Scientific/Engineering\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n    \"Topic :: Scientific/Engineering :: Mathematics\",\n    \"Topic :: Software Development\",\n    \"Topic :: Software Development :: Libraries\",\n    \"Topic :: Software Development :: Libraries :: Python Modules\",\n]\n\n_description = (\n    \"A library that implements optionally monotonic lattice based models.\")\n_long_description = \"\"\"\\\nTensorFlow Lattice is a library that implements fast-to-evaluate and\ninterpretable (optionally monotonic) lattice based models, which are also known\nas *interpolated look-up tables*. 
The library includes a collection of Keras\nlayers for lattices and feature calibration that can be composed into custom\nmodels or used inside generic premade models.\n\"\"\"\n\nsetup(\n    name=_name,\n    version=__version__,\n    author=\"Google Inc.\",\n    author_email=\"no-reply@google.com\",\n    license=\"Apache 2.0\",\n    classifiers=_classifiers,\n    install_requires=_install_requires,\n    extras_require=_extras_require,\n    packages=find_packages(),\n    include_package_data=True,\n    description=_description,\n    long_description=_long_description,\n    long_description_content_type=\"text/markdown\",\n    keywords=\"tensorflow lattice calibration machine learning\",\n    url=(\n        \"https://github.com/tensorflow/lattice\"\n    ),\n)\n"
  },
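A quick illustration of the version string the nightly branch of setup.py produces (the exact date suffix naturally depends on the build day):

```python
import datetime

__version__ = "2.1.1"
# Nightly packages get a ".devYYYYMMDD" suffix, e.g. "2.1.1.dev20240101".
nightly_version = __version__ + datetime.datetime.now().strftime(".dev%Y%m%d")
print(nightly_version)
```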
  {
    "path": "tensorflow_lattice/BUILD",
    "content": "# Copyright 2017 The TensorFlow Lattice Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\npackage(\n    default_visibility = [\n        \"//visibility:public\",\n    ],\n)\n\nlicenses([\"notice\"])\n\nexports_files([\"LICENSE\"])\n\npy_library(\n    name = \"tensorflow_lattice\",\n    srcs = [\n        \"__init__.py\",\n        \"layers/__init__.py\",\n    ],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \"//tensorflow_lattice/python:aggregation_layer\",\n        \"//tensorflow_lattice/python:categorical_calibration_layer\",\n        \"//tensorflow_lattice/python:categorical_calibration_lib\",\n        \"//tensorflow_lattice/python:cdf_layer\",\n        \"//tensorflow_lattice/python:conditional_cdf\",\n        \"//tensorflow_lattice/python:conditional_pwl_calibration\",\n        \"//tensorflow_lattice/python:configs\",\n        \"//tensorflow_lattice/python:kronecker_factored_lattice_layer\",\n        \"//tensorflow_lattice/python:kronecker_factored_lattice_lib\",\n        \"//tensorflow_lattice/python:lattice_layer\",\n        \"//tensorflow_lattice/python:lattice_lib\",\n        \"//tensorflow_lattice/python:linear_layer\",\n        \"//tensorflow_lattice/python:linear_lib\",\n        \"//tensorflow_lattice/python:model_info\",\n        \"//tensorflow_lattice/python:parallel_combination_layer\",\n        \"//tensorflow_lattice/python:premade\",\n        \"//tensorflow_lattice/python:premade_lib\",\n        \"//tensorflow_lattice/python:pwl_calibration_layer\",\n        \"//tensorflow_lattice/python:pwl_calibration_lib\",\n        \"//tensorflow_lattice/python:rtl_layer\",\n        \"//tensorflow_lattice/python:test_utils\",\n        \"//tensorflow_lattice/python:utils\",\n    ],\n)\n"
  },
  {
    "path": "tensorflow_lattice/__init__.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tensorflow Lattice Library.\n\nThis package provides functions and classes for lattice modeling.\n\"\"\"\n\nfrom __future__ import absolute_import\n\nimport tensorflow_lattice.layers\nfrom tensorflow_lattice.python import aggregation_layer\nfrom tensorflow_lattice.python import categorical_calibration_layer\nfrom tensorflow_lattice.python import categorical_calibration_lib\nfrom tensorflow_lattice.python import cdf_layer\nfrom tensorflow_lattice.python import conditional_cdf\nfrom tensorflow_lattice.python import conditional_pwl_calibration\nfrom tensorflow_lattice.python import configs\nfrom tensorflow_lattice.python import kronecker_factored_lattice_layer\nfrom tensorflow_lattice.python import kronecker_factored_lattice_lib\nfrom tensorflow_lattice.python import lattice_layer\nfrom tensorflow_lattice.python import lattice_lib\nfrom tensorflow_lattice.python import linear_layer\nfrom tensorflow_lattice.python import linear_lib\nfrom tensorflow_lattice.python import model_info\nfrom tensorflow_lattice.python import parallel_combination_layer\nfrom tensorflow_lattice.python import premade\nfrom tensorflow_lattice.python import premade_lib\nfrom tensorflow_lattice.python import pwl_calibration_layer\nfrom tensorflow_lattice.python import pwl_calibration_lib\nfrom tensorflow_lattice.python import test_utils\nfrom tensorflow_lattice.python import utils\n"
  },
  {
    "path": "tensorflow_lattice/layers/__init__.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"'layers' namespace for TFL layers.\"\"\"\n\nfrom tensorflow_lattice.python.aggregation_layer import Aggregation\nfrom tensorflow_lattice.python.categorical_calibration_layer import CategoricalCalibration\nfrom tensorflow_lattice.python.cdf_layer import CDF\nfrom tensorflow_lattice.python.kronecker_factored_lattice_layer import KroneckerFactoredLattice\nfrom tensorflow_lattice.python.lattice_layer import Lattice\nfrom tensorflow_lattice.python.linear_layer import Linear\nfrom tensorflow_lattice.python.parallel_combination_layer import ParallelCombination\nfrom tensorflow_lattice.python.pwl_calibration_layer import PWLCalibration\nfrom tensorflow_lattice.python.rtl_layer import RTL\n"
  },
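These re-exports define the public `tfl.layers` namespace used throughout the examples; a minimal sketch of how it is typically reached (keypoints and sizes here are arbitrary):

```python
import numpy as np
import tensorflow_lattice as tfl

# Layers are referenced through the re-exported names rather than the
# tensorflow_lattice.python submodules.
calibrator = tfl.layers.PWLCalibration(
    input_keypoints=np.linspace(0.0, 1.0, num=5),
    output_min=0.0,
    output_max=1.0)
lattice = tfl.layers.Lattice(lattice_sizes=[2, 2])
```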
  {
    "path": "tensorflow_lattice/python/BUILD",
    "content": "# Copyright 2019 The TensorFlow Lattice Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\nload(\"//third_party/bazel_rules/rules_python/python:py_library.bzl\", \"py_library\")\nload(\"//third_party/bazel_rules/rules_python/python:py_test.bzl\", \"py_test\")\n\npackage(\n    default_visibility = [\n        \"//tensorflow_lattice:__subpackages__\",\n    ],\n)\n\nlicenses([\"notice\"])\n\n# Build rules are alphabetized. Please add new rules alphabetically\n# to maintain the ordering.\npy_library(\n    name = \"aggregation_layer\",\n    srcs = [\"aggregation_layer.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        # tensorflow dep,\n    ],\n)\n\npy_test(\n    name = \"aggregation_test\",\n    srcs = [\"aggregation_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":aggregation_layer\",\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"categorical_calibration_layer\",\n    srcs = [\"categorical_calibration_layer.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":categorical_calibration_lib\",\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_library(\n    name = \"categorical_calibration_lib\",\n    srcs = [\"categorical_calibration_lib.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":internal_utils\",\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_test(\n    name = \"categorical_calibration_test\",\n    size = \"large\",\n    srcs = [\"categorical_calibration_test.py\"],\n    python_version = \"PY3\",\n    # shard_count = 4,\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":categorical_calibration_layer\",\n        \":parallel_combination_layer\",\n        \":test_utils\",\n        # absl/logging dep,\n        # absl/testing:parameterized dep,\n        # numpy dep,\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"cdf_layer\",\n    srcs = [\"cdf_layer.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":utils\",\n        # tensorflow dep,\n    ],\n)\n\npy_test(\n    name = \"cdf_test\",\n    size = \"large\",\n    srcs = [\"cdf_test.py\"],\n    python_version = \"PY3\",\n    # shard_count = 12,\n    srcs_version = \"PY3\",\n    deps = [\n        \":cdf_layer\",\n        \":test_utils\",\n        \":utils\",\n        # absl/logging dep,\n        # absl/testing:parameterized dep,\n        # numpy dep,\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"configs\",\n    srcs = [\"configs.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        # absl/logging dep,\n        # tensorflow dep,\n    ],\n)\n\npy_test(\n    name = \"configs_test\",\n    size = \"small\",\n    srcs = [\"configs_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":categorical_calibration_layer\",\n        \":configs\",\n        \":lattice_layer\",\n        
\":linear_layer\",\n        \":premade\",\n        \":pwl_calibration_layer\",\n        # absl/logging dep,\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"internal_utils\",\n    srcs = [\"internal_utils.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        # tensorflow dep,\n    ],\n)\n\npy_test(\n    name = \"internal_utils_test\",\n    srcs = [\"internal_utils_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":internal_utils\",\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"kronecker_factored_lattice_layer\",\n    srcs = [\"kronecker_factored_lattice_layer.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":kronecker_factored_lattice_lib\",\n        \":utils\",\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"kronecker_factored_lattice_lib\",\n    srcs = [\"kronecker_factored_lattice_lib.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":utils\",\n        # numpy dep,\n        # tensorflow dep,\n    ],\n)\n\npy_test(\n    name = \"kronecker_factored_lattice_test\",\n    size = \"large\",\n    srcs = [\"kronecker_factored_lattice_test.py\"],\n    python_version = \"PY3\",\n    # shard_count = 12,\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":kronecker_factored_lattice_layer\",\n        \":kronecker_factored_lattice_lib\",\n        \":test_utils\",\n        # absl/logging dep,\n        # absl/testing:parameterized dep,\n        # numpy dep,\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"lattice_layer\",\n    srcs = [\"lattice_layer.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":lattice_lib\",\n        \":pwl_calibration_layer\",\n        \":utils\",\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_library(\n    name = \"lattice_lib\",\n    srcs = [\"lattice_lib.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":utils\",\n        # absl/logging dep,\n        # numpy dep,\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_test(\n    name = \"lattice_test\",\n    size = \"large\",\n    srcs = [\"lattice_test.py\"],\n    python_version = \"PY3\",\n    # shard_count = 12,\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":lattice_layer\",\n        \":test_utils\",\n        # absl/logging dep,\n        # absl/testing:parameterized dep,\n        # numpy dep,\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"linear_layer\",\n    srcs = [\"linear_layer.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":linear_lib\",\n        \":utils\",\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_library(\n    name = \"linear_lib\",\n    srcs = [\"linear_lib.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":internal_utils\",\n        \":utils\",\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_test(\n    name = \"linear_test\",\n    size = \"large\",\n    srcs = [\"linear_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":linear_layer\",\n        \":test_utils\",\n        \":utils\",\n        # absl/logging dep,\n        # absl/testing:parameterized dep,\n        # numpy dep,\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"model_info\",\n    srcs = [\"model_info.py\"],\n    srcs_version = \"PY2AND3\",\n)\n\npy_library(\n    name = \"parallel_combination_layer\",\n    srcs = [\"parallel_combination_layer.py\"],\n 
   srcs_version = \"PY2AND3\",\n    deps = [\n        \":categorical_calibration_layer\",\n        \":lattice_layer\",\n        \":linear_layer\",\n        \":pwl_calibration_layer\",\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_test(\n    name = \"parallel_combination_test\",\n    size = \"large\",\n    srcs = [\"parallel_combination_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":lattice_layer\",\n        \":parallel_combination_layer\",\n        # absl/testing:parameterized dep,\n        # numpy dep,\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"premade\",\n    srcs = [\"premade.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":aggregation_layer\",\n        \":categorical_calibration_layer\",\n        \":configs\",\n        \":kronecker_factored_lattice_layer\",\n        \":lattice_layer\",\n        \":parallel_combination_layer\",\n        \":premade_lib\",\n        \":pwl_calibration_layer\",\n        # absl/logging dep,\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"premade_lib\",\n    srcs = [\"premade_lib.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":aggregation_layer\",\n        \":categorical_calibration_layer\",\n        \":configs\",\n        \":kronecker_factored_lattice_layer\",\n        \":kronecker_factored_lattice_lib\",\n        \":lattice_layer\",\n        \":lattice_lib\",\n        \":linear_layer\",\n        \":pwl_calibration_layer\",\n        \":rtl_layer\",\n        \":utils\",\n        # absl/logging dep,\n        # numpy dep,\n        # six dep,\n        # tensorflow dep,\n    ],\n)\n\npy_test(\n    name = \"premade_test\",\n    size = \"large\",\n    srcs = [\"premade_test.py\"],\n    python_version = \"PY3\",\n    # shard_count = 10,\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":configs\",\n        \":premade\",\n        \":premade_lib\",\n        # absl/logging dep,\n        # absl/testing:parameterized dep,\n        # numpy dep,\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"pwl_calibration_layer\",\n    srcs = [\"pwl_calibration_layer.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":pwl_calibration_lib\",\n        \":utils\",\n        # absl/logging dep,\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_library(\n    name = \"pwl_calibration_lib\",\n    srcs = [\"pwl_calibration_lib.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":utils\",\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_test(\n    name = \"pwl_calibration_test\",\n    size = \"large\",\n    srcs = [\"pwl_calibration_test.py\"],\n    python_version = \"PY3\",\n    # shard_count = 12,\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":parallel_combination_layer\",\n        \":pwl_calibration_layer\",\n        \":test_utils\",\n        \":utils\",\n        # absl/logging dep,\n        # absl/testing:parameterized dep,\n        # numpy dep,\n        # tensorflow dep,\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_library(\n    name = \"rtl_layer\",\n    srcs = [\"rtl_layer.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":kronecker_factored_lattice_layer\",\n        \":lattice_layer\",\n        \":rtl_lib\",\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_library(\n    name = \"rtl_lib\",\n    srcs = [\"rtl_lib.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        # six dep,\n    
],\n)\n\npy_test(\n    name = \"rtl_test\",\n    size = \"large\",\n    srcs = [\"rtl_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":linear_layer\",\n        \":pwl_calibration_layer\",\n        \":rtl_layer\",\n        # absl/testing:parameterized dep,\n        # numpy dep,\n        # tensorflow dep,\n    ],\n)\n\npy_library(\n    name = \"test_utils\",\n    srcs = [\"test_utils.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        # absl/logging dep,\n        # numpy dep,\n    ],\n)\n\npy_library(\n    name = \"utils\",\n    srcs = [\"utils.py\"],\n    srcs_version = \"PY2AND3\",\n    deps = [\n        # six dep,\n    ],\n)\n\npy_library(\n    name = \"conditional_pwl_calibration\",\n    srcs = [\"conditional_pwl_calibration.py\"],\n    deps = [\n        # numpy dep,\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_library(\n    name = \"conditional_cdf\",\n    srcs = [\"conditional_cdf.py\"],\n    deps = [\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_test(\n    name = \"conditional_cdf_test\",\n    srcs = [\"conditional_cdf_test.py\"],\n    deps = [\n        \":conditional_cdf\",\n        # absl/testing:parameterized dep,\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_test(\n    name = \"conditional_pwl_calibration_test\",\n    srcs = [\"conditional_pwl_calibration_test.py\"],\n    deps = [\n        \":conditional_pwl_calibration\",\n        # tensorflow:tensorflow_no_contrib dep,\n    ],\n)\n\npy_test(\n    name = \"utils_test\",\n    srcs = [\"utils_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY2AND3\",\n    deps = [\n        \":utils\",\n        # absl/testing:parameterized dep,\n        # tensorflow dep,\n    ],\n)\n"
  },
  {
    "path": "tensorflow_lattice/python/__init__.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TensorFlow Lattice python package.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n"
  },
  {
    "path": "tensorflow_lattice/python/aggregation_layer.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Layer which represents aggregation function.\n\nSee class level comment.\n\nThis layer applies the provided model to the ragged input tensor and aggregates\nthe results.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, 'version', None)\nif version_fn and version_fn().startswith('3.'):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n\nclass Aggregation(keras.layers.Layer):\n  # pyformat: disable\n  \"\"\"Layer which represents an aggregation function.\n\n  Calls the model on each of the ragged dimensions and takes the mean.\n\n  Input shape:\n  A list or dictionary with num_input_dims Rank-2 ragged tensors with\n  shape: (batch_size, ?)\n\n  Output shape:\n  Rank-2 tensor with shape: (batch_size, 1)\n\n  Attributes:\n    - All `__init__ `arguments.\n\n  Example:\n\n  ```python\n  model = keras.Model(inputs=inputs, outputs=outputs)\n  layer = tfl.layers.Aggregation(model)\n  ```\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self, model, **kwargs):\n    \"\"\"initializes an instance of `Aggregation`.\n\n    Args:\n      model: A keras.Model instance.\n      **kwargs: Other args passed to `keras.layers.Layer` initializer.\n\n    Raises:\n      ValueError: if model is not at `keras.Model` instance.\n    \"\"\"\n    if not isinstance(model, keras.Model):\n      raise ValueError('Model must be a keras.Model instance.')\n    super(Aggregation, self).__init__(**kwargs)\n    # This flag enables inputs to be Ragged Tensors\n    self._supports_ragged_inputs = True\n    self.model = model\n\n  def call(self, x):\n    \"\"\"Standard Keras call() method.\"\"\"\n    return tf.reduce_mean(tf.ragged.map_flat_values(self.model, x), axis=1)\n\n  def get_config(self):\n    \"\"\"Standard Keras get_config() method.\"\"\"\n    config = super(Aggregation, self).get_config().copy()\n    config.update(\n        {'model': keras.utils.legacy.serialize_keras_object(self.model)}\n    )\n    return config\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    model = keras.utils.legacy.deserialize_keras_object(\n        config.pop('model'), custom_objects=custom_objects\n    )\n    return cls(model, **config)\n"
  },
  {
    "path": "tensorflow_lattice/python/aggregation_test.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for Tensorflow Lattice premade.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow_lattice.python import aggregation_layer\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, 'version', None)\nif version_fn and version_fn().startswith('3.'):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n\ntest_input = [\n    tf.ragged.constant([[1, 2], [1, 2, 3], [3]]),\n    tf.ragged.constant([[4, 5], [4, 4, 4], [6]]),\n    tf.ragged.constant([[1, 6], [5, 5, 5], [9]])\n]\n\nexpected_output = tf.constant([32, 40, 162])\n\n\nclass AggregationTest(tf.test.TestCase):\n\n  def testAggregationLayer(self):\n    # First we test our assertion that the model must be a keras.Model\n    with self.assertRaisesRegex(ValueError,\n                                'Model must be a keras.Model instance.'):\n      aggregation_layer.Aggregation(None)\n    # Now let's make sure our layer aggregates properly.\n    inputs = [keras.Input(shape=()) for _ in range(len(test_input))]\n    output = keras.layers.multiply(inputs)\n    model = keras.Model(inputs=inputs, outputs=output)\n    agg_layer = aggregation_layer.Aggregation(model)\n    self.assertAllEqual(agg_layer(test_input), expected_output)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
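The expected values in the test follow from the layer definition: the product model is applied to the flat ragged values and each row is then averaged. A plain-NumPy check of the same arithmetic, just to spell out where 32, 40, and 162 come from:

```python
import numpy as np

rows = [
    # (input_1, input_2, input_3) values for each ragged row of test_input.
    ([1, 2], [4, 5], [1, 6]),            # products 4, 60      -> mean 32
    ([1, 2, 3], [4, 4, 4], [5, 5, 5]),   # products 20, 40, 60 -> mean 40
    ([3], [6], [9]),                     # product 162         -> mean 162
]
expected = [float(np.mean(np.prod(np.array(row), axis=0))) for row in rows]
print(expected)  # [32.0, 40.0, 162.0]
```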
  {
    "path": "tensorflow_lattice/python/categorical_calibration_layer.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Categorical calibration layer with monotonicity and bound constraints.\n\nKeras implementation of tensorflow lattice categorical calibration layer. This\nlayer takes single or multi-dimensional input and transforms it using lookup\ntables satisfying monotonicity and bounds constraints if specified.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nimport tensorflow as tf\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\nfrom . import categorical_calibration_lib\n\nDEFAULT_INPUT_VALUE_NAME = \"default_input_value\"\nCATEGORICAL_CALIBRATION_KERNEL_NAME = \"categorical_calibration_kernel\"\n\n# TODO: implement variation/variance regularizer.\n\n\nclass CategoricalCalibration(keras.layers.Layer):\n  # pyformat: disable\n  \"\"\"Categorical calibration layer with monotonicity and bound constraints.\n\n  This layer takes input of shape `(batch_size, units)` or `(batch_size, 1)` and\n  transforms it using `units` number of lookup tables satisfying monotonicity\n  and bounds constraints if specified. If multi dimensional input is provided,\n  each output will be for the corresponding input, otherwise all calibration\n  functions will act on the same input. All units share the same layer\n  configuration, but each one has their separate set of trained parameters.\n\n  Input shape:\n  Rank-2 tensor with shape:  `(batch_size, units)` or `(batch_size, 1)`.\n\n  Output shape:\n  If units > 1 and split_outputs is True, a length `units` list of Rank-2\n    tensors with shape `(batch_size, 1)`. Otherwise, a Rank-2 tensor with shape:\n    `(batch_size, units)`\n\n  Attributes:\n    - All `__init__` args.\n    kernel: TF variable of shape `(batch_size, units)` which stores the lookup\n    table.\n\n  Example:\n\n  ```python\n  calibrator = tfl.layers.CategoricalCalibration(\n      # Number of categories.\n      num_buckets=3,\n      # Output can be bounded.\n      output_min=0.0,\n      output_max=1.0,\n      # For categorical calibration layer monotonicity is specified for pairs of\n      # indices of categories. 
Output for first category in pair will be less\n      # than or equal to output for second category.\n      monotonicities=[(0, 1), (0, 2)])\n  ```\n\n  Usage with functional models:\n\n  ```python\n  input_feature = keras.layers.Input(shape=[1])\n  calibrated_feature = tfl.layers.CategoricalCalibration(\n      num_buckets=3,\n      output_min=0.0,\n      output_max=1.0,\n      monotonicities=[(0, 1), (0, 2)],\n  )(feature)\n  ...\n  model = keras.models.Model(\n      inputs=[input_feature, ...],\n      outputs=...)\n  ```\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self,\n               num_buckets,\n               units=1,\n               output_min=None,\n               output_max=None,\n               monotonicities=None,\n               kernel_initializer=\"uniform\",\n               kernel_regularizer=None,\n               default_input_value=None,\n               split_outputs=False,\n               **kwargs):\n    # pyformat: disable\n    \"\"\"Initializes a `CategoricalCalibration` instance.\n\n    Args:\n      num_buckets: Number of categories.\n      units: Output dimension of the layer. See class comments for details.\n      output_min: Minimum output of calibrator.\n      output_max: Maximum output of calibrator.\n      monotonicities: List of pairs with `(i, j)` indices indicating `output(i)`\n        should be less than or equal to `output(j)`.\n      kernel_initializer: None or one of:\n        - `'uniform'`: If `output_min` and `output_max` are provided initial\n          values will be uniformly sampled from `[output_min, output_max]`\n          range.\n        - `'constant'`: If `output_min` and `output_max` are provided all output\n          values will be initlized to the constant\n          `(output_min + output_max) / 2`.\n        - Any Keras initializer object.\n      kernel_regularizer: None or single element or list of any Keras\n        regularizer objects.\n      default_input_value: If set, all inputs which are equal to this value will\n        be treated as default and mapped to the last bucket.\n      split_outputs: Whether to split the output tensor into a list of\n        outputs for each unit. 
Ignored if units < 2.\n      **kwargs: Other args passed to `keras.layers.Layer` initializer.\n\n    Raises:\n      ValueError: If layer hyperparameters are invalid.\n    \"\"\"\n    # pyformat: enable\n    dtype = kwargs.pop(\"dtype\", tf.float32)  # output dtype\n    super(CategoricalCalibration, self).__init__(dtype=dtype, **kwargs)\n\n    categorical_calibration_lib.verify_hyperparameters(\n        num_buckets=num_buckets,\n        output_min=output_min,\n        output_max=output_max,\n        monotonicities=monotonicities)\n    self.num_buckets = num_buckets\n    self.units = units\n    self.output_min = output_min\n    self.output_max = output_max\n    self.monotonicities = monotonicities\n    if output_min is not None and output_max is not None:\n      if kernel_initializer == \"constant\":\n        kernel_initializer = keras.initializers.Constant(\n            (output_min + output_max) / 2)\n      elif kernel_initializer == \"uniform\":\n        kernel_initializer = keras.initializers.RandomUniform(\n            output_min, output_max)\n    self.kernel_initializer = keras.initializers.get(kernel_initializer)\n    self.kernel_regularizer = []\n    if kernel_regularizer:\n      if callable(kernel_regularizer):\n        kernel_regularizer = [kernel_regularizer]\n      for reg in kernel_regularizer:\n        self.kernel_regularizer.append(keras.regularizers.get(reg))\n    self.default_input_value = default_input_value\n    self.split_outputs = split_outputs\n\n  def build(self, input_shape):\n    \"\"\"Standard Keras build() method.\"\"\"\n    if (self.output_min is not None or self.output_max is not None or\n        self.monotonicities):\n      constraints = CategoricalCalibrationConstraints(\n          output_min=self.output_min,\n          output_max=self.output_max,\n          monotonicities=self.monotonicities)\n    else:\n      constraints = None\n\n    if not self.kernel_regularizer:\n      kernel_reg = None\n    elif len(self.kernel_regularizer) == 1:\n      kernel_reg = self.kernel_regularizer[0]\n    else:\n      # Keras interface assumes only one regularizer, so summ all regularization\n      # losses which we have.\n      kernel_reg = lambda x: tf.add_n([r(x) for r in self.kernel_regularizer])\n\n    # categorical calibration layer kernel is units-column matrix with value of\n    # output(i) = self.kernel[i]. 
Default value converted to the last index.\n    self.kernel = self.add_weight(\n        CATEGORICAL_CALIBRATION_KERNEL_NAME,\n        shape=[self.num_buckets, self.units],\n        initializer=self.kernel_initializer,\n        regularizer=kernel_reg,\n        constraint=constraints,\n        dtype=self.dtype)\n\n    if self.kernel_regularizer and not tf.executing_eagerly():\n      # Keras has its own mechanism to handle regularization losses which\n      # does not use GraphKeys, but we want to also add losses to graph keys so\n      # they are easily accessable when layer is being used outside of Keras.\n      # Adding losses to GraphKeys will not interfer with Keras.\n      for reg in self.kernel_regularizer:\n        tf.compat.v1.add_to_collection(\n            tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, reg(self.kernel))\n\n    super(CategoricalCalibration, self).build(input_shape)\n\n  def call(self, inputs):\n    \"\"\"Standard Keras call() method.\"\"\"\n    if inputs.dtype not in [tf.uint8, tf.int32, tf.int64]:\n      inputs = tf.cast(inputs, dtype=tf.int32)\n\n    if self.default_input_value is not None:\n      default_input_value_tensor = tf.constant(\n          int(self.default_input_value),\n          dtype=inputs.dtype,\n          name=DEFAULT_INPUT_VALUE_NAME)\n      replacement = tf.zeros_like(inputs) + (self.num_buckets - 1)\n      inputs = tf.where(\n          tf.equal(inputs, default_input_value_tensor), replacement, inputs)\n\n    # We can't use tf.gather_nd(self.kernel, inputs) as it doesn't support\n    # constraints (constraint functions are not supported for IndexedSlices).\n    # Instead we use matrix multiplication by one-hot encoding of the index.\n    if self.units == 1:\n      # This can be slightly faster as it uses matmul.\n      return tf.matmul(\n          tf.one_hot(tf.squeeze(inputs, axis=[-1]), depth=self.num_buckets),\n          self.kernel)\n    result = tf.reduce_sum(\n        tf.one_hot(inputs, axis=1, depth=self.num_buckets) * self.kernel,\n        axis=1)\n\n    if self.split_outputs:\n      result = tf.split(result, self.units, axis=1)\n\n    return result\n\n  def compute_output_shape(self, input_shape):\n    \"\"\"Standard Keras compute_output_shape() method.\"\"\"\n    del input_shape\n    if self.units > 1 and self.split_outputs:\n      return [(None, 1)] * self.units\n    else:\n      return (None, self.units)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    config = {\n        \"num_buckets\": self.num_buckets,\n        \"units\": self.units,\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n        \"monotonicities\": self.monotonicities,\n        \"kernel_initializer\":\n            keras.initializers.serialize(\n                self.kernel_initializer, use_legacy_format=True),\n        \"kernel_regularizer\":\n            [keras.regularizers.serialize(r, use_legacy_format=True)\n             for r in self.kernel_regularizer],\n        \"default_input_value\": self.default_input_value,\n        \"split_outputs\": self.split_outputs,\n    }  # pyformat: disable\n    config.update(super(CategoricalCalibration, self).get_config())\n    return config\n\n  def assert_constraints(self, eps=1e-6):\n    \"\"\"Asserts that layer weights satisfy all constraints.\n\n    In graph mode builds and returns list of assertion ops. 
Note that ops will\n    be created at the moment when this function is being called.\n    In eager mode directly executes assertions.\n\n    Args:\n      eps: Allowed constraints violation.\n\n    Returns:\n      List of assertion ops in graph mode or immediately asserts in eager mode.\n    \"\"\"\n    return categorical_calibration_lib.assert_constraints(\n        weights=self.kernel,\n        output_min=self.output_min,\n        output_max=self.output_max,\n        monotonicities=self.monotonicities,\n        eps=eps)\n\n\nclass CategoricalCalibrationConstraints(keras.constraints.Constraint):\n  # pyformat: disable\n  \"\"\"Monotonicity and bounds constraints for categorical calibration layer.\n\n  Updates the weights of CategoricalCalibration layer to satify bound and\n  monotonicity constraints. The update is an approximate L2 projection into the\n  constrained parameter space.\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self, output_min=None, output_max=None, monotonicities=None):\n    \"\"\"Initializes an instance of `CategoricalCalibrationConstraints`.\n\n    Args:\n      output_min: Minimum possible output of categorical function.\n      output_max: Maximum possible output of categorical function.\n      monotonicities: Monotonicities of CategoricalCalibration layer.\n    \"\"\"\n    categorical_calibration_lib.verify_hyperparameters(\n        output_min=output_min,\n        output_max=output_max,\n        monotonicities=monotonicities)\n    self.monotonicities = monotonicities\n    self.output_min = output_min\n    self.output_max = output_max\n\n  def __call__(self, w):\n    \"\"\"Applies constraints to w.\"\"\"\n    return categorical_calibration_lib.project(\n        weights=w,\n        output_min=self.output_min,\n        output_max=self.output_max,\n        monotonicities=self.monotonicities)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n        \"monotonicities\": self.monotonicities,\n    }  # pyformat: disable\n"
  },
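  A short standalone sketch of the lookup mechanism described in `CategoricalCalibration.call()`: because constraint functions cannot be applied to `IndexedSlices`, the layer realizes `output(i) = kernel[i]` through a one-hot matmul rather than `tf.gather_nd`, and remaps `default_input_value` to the last bucket. The tensor values below are illustrative assumptions, not taken from the library.

  ```python
  import tensorflow as tf

  num_buckets, units = 4, 1
  kernel = tf.constant([[0.1], [0.5], [0.9], [-1.0]])  # shape (num_buckets, units)

  # Batch of category indices, shape (batch_size, 1); -1 plays the role of
  # default_input_value and is remapped to the last bucket (index 3).
  inputs = tf.constant([[0], [2], [-1]], dtype=tf.int32)
  default_input_value = -1
  replacement = tf.zeros_like(inputs) + (num_buckets - 1)
  inputs = tf.where(tf.equal(inputs, default_input_value), replacement, inputs)

  # One-hot matmul, as used by the layer when units == 1 ...
  one_hot = tf.one_hot(tf.squeeze(inputs, axis=[-1]), depth=num_buckets)
  outputs = tf.matmul(one_hot, kernel)

  # ... is equivalent to a direct row lookup of the kernel.
  assert tf.reduce_all(
      outputs == tf.gather(kernel, tf.squeeze(inputs, axis=[-1]))).numpy()
  print(outputs.numpy())  # [[0.1], [0.9], [-1.0]]
  ```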
  {
    "path": "tensorflow_lattice/python/categorical_calibration_lib.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Helpers and computations of categorical calibration layer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom . import internal_utils\nimport tensorflow as tf\n\n\ndef project(weights, output_min, output_max, monotonicities):\n  \"\"\"Monotonicity/bounds constraints implementation for categorical calibration.\n\n  Returns the approximate L2 projection of the CategoricalCalibration weights\n  into the constrained parameter space.\n\n  Args:\n    weights: Tensor which represents weights of Categorical calibration layer.\n    output_min: Lower bound constraint on weights.\n    output_max: Upper bound constraint on weights.\n    monotonicities: List of pair of indices `(i, j)`, indicating constraint\n      `weight[i] <= weight[j]`.\n\n  Returns:\n    Projected `weights` tensor.\n\n  Raises:\n    ValueError: If monotonicities are not of the correct format or are circular.\n  \"\"\"\n  num_buckets = weights.shape[0]\n  verify_hyperparameters(\n      num_buckets=num_buckets,\n      output_min=output_min,\n      output_max=output_max,\n      monotonicities=monotonicities)\n\n  projected_weights = weights\n\n  if monotonicities:\n    projected_weights = (\n        internal_utils.approximately_project_categorical_partial_monotonicities(\n            projected_weights, monotonicities))\n\n  if output_min is not None:\n    projected_weights = tf.maximum(projected_weights, output_min)\n  if output_max is not None:\n    projected_weights = tf.minimum(projected_weights, output_max)\n  return projected_weights\n\n\ndef assert_constraints(weights,\n                       output_min,\n                       output_max,\n                       monotonicities,\n                       debug_tensors=None,\n                       eps=1e-6):\n  \"\"\"Asserts that `weights` satisfiy constraints.\n\n  Args:\n    weights: Tensor which represents weights of Categorical calibration layer.\n    output_min: Lower bound constraint on weights.\n    output_max: Upper bound constraint on weights.\n    monotonicities: List of pair of indices `(i, j)`, indicating constraint\n      `weight[i] <= weight[j]`.\n    debug_tensors: None or list of anything convertible to tensor (for example\n      tensors or strings) which will be printed in case of constraints\n      violation.\n    eps: Allowed constraints violation.\n\n  Returns:\n    List of assertion ops in graph mode or immideately asserts in eager mode.\n  \"\"\"\n  num_buckets = weights.shape[0]\n  verify_hyperparameters(\n      num_buckets=num_buckets,\n      output_min=output_min,\n      output_max=output_max,\n      monotonicities=monotonicities)\n\n  info = [\"Outputs: \", weights, \"Epsilon: \", eps]\n  if debug_tensors:\n    info += debug_tensors\n  asserts = []\n\n  if output_min is not None:\n    min_output = tf.reduce_min(weights)\n    asserts.append(\n        tf.Assert(\n            
min_output >= output_min - eps,\n            data=[\"Lower bound violation.\", \"output_min:\", output_min] + info,\n            summarize=num_buckets))\n\n  if output_max is not None:\n    max_output = tf.reduce_max(weights)\n    asserts.append(\n        tf.Assert(\n            max_output <= output_max + eps,\n            data=[\"Upper bound violation.\", \"output_max:\", output_max] + info,\n            summarize=num_buckets))\n\n  if monotonicities:\n    left = tf.gather_nd(weights, [[i] for (i, j) in monotonicities])\n    right = tf.gather_nd(weights, [[j] for (i, j) in monotonicities])\n    asserts.append(\n        tf.Assert(\n            tf.reduce_min(left - right) < eps,\n            data=[\"Monotonicity violation.\", \"monotonicities:\", monotonicities]\n            + info,\n            summarize=num_buckets))\n\n  return asserts\n\n\ndef verify_hyperparameters(num_buckets=None,\n                           output_min=None,\n                           output_max=None,\n                           monotonicities=None):\n  \"\"\"Verifies that all given hyperparameters are consistent.\n\n  See `tfl.layers.CategoricalCalibration` class level comment for detailes.\n\n  Args:\n    num_buckets: `num_buckets` of CategoricalCalibration layer.\n    output_min: `smallest output` of CategoricalCalibration layer.\n    output_max: `largest output` of CategoricalCalibration layer.\n    monotonicities: `monotonicities` of CategoricalCalibration layer.\n\n  Raises:\n    ValueError: If parameters are incorrect or inconsistent.\n  \"\"\"\n  if output_min is not None and output_max is not None:\n    if output_max < output_min:\n      raise ValueError(\n          \"If specified output_max must be greater than output_min. \"\n          \"They are: ({}, {})\".format(output_min, output_max))\n\n  if monotonicities:\n    if (not isinstance(monotonicities, list) or not all(\n        isinstance(m, (list, tuple)) and len(m) == 2 for m in monotonicities)):\n      raise ValueError(\n          \"Monotonicities should be a list of pairs (list/tuples).\")\n    for (i, j) in monotonicities:\n      if (i < 0 or j < 0 or (num_buckets is not None and\n                             (i >= num_buckets or j >= num_buckets))):\n        raise ValueError(\n            \"Monotonicities should be pairs of be indices in range \"\n            \"[0, num_buckets). They are: {}\".format(monotonicities))\n"
  },
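  A minimal usage sketch for the helpers above: `project()` first resolves the pairwise monotonicity constraints and then clips to the output bounds, and `assert_constraints()` executes its checks immediately in eager mode. The weight values and hyperparameters below are illustrative assumptions.

  ```python
  import tensorflow as tf
  from tensorflow_lattice.python import categorical_calibration_lib as lib

  # Weights of a 3-bucket, 1-unit calibrator; bucket 0 currently exceeds
  # bucket 1 (violating the requested pair (0, 1)) and bucket 2 exceeds
  # output_max.
  weights = tf.constant([[0.8], [0.2], [1.5]])

  projected = lib.project(
      weights,
      output_min=0.0,
      output_max=1.0,
      monotonicities=[(0, 1)])  # require output(0) <= output(1)

  # In eager mode this raises if the projected weights still violate the
  # bounds or the (0, 1) ordering; in graph mode it returns assertion ops.
  lib.assert_constraints(
      weights=projected,
      output_min=0.0,
      output_max=1.0,
      monotonicities=[(0, 1)])
  print(projected.numpy())
  ```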
  {
    "path": "tensorflow_lattice/python/categorical_calibration_test.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for categorical calibration layer.\n\nThis test should be run with \"-c opt\" since otherwise it's slow.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl import logging\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_lattice.python import categorical_calibration_layer as categorical_calibraion\nfrom tensorflow_lattice.python import parallel_combination_layer as parallel_combination\nfrom tensorflow_lattice.python import test_utils\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n\nclass CategoricalCalibrationLayerTest(parameterized.TestCase, tf.test.TestCase):\n\n  def setUp(self):\n    super(CategoricalCalibrationLayerTest, self).setUp()\n    self._disable_all = False\n    self._loss_eps = 1e-2\n    self._loss_diff_eps = 1e-4\n    keras.utils.set_random_seed(42)\n\n  def _ResetAllBackends(self):\n    keras.backend.clear_session()\n    tf.compat.v1.reset_default_graph()\n\n  def _ScatterXUniformly(self, units, num_points, num_buckets,\n                         missing_probability, default_input_value):\n    \"\"\"Randomly uniformly scatters points across input space.\"\"\"\n    data = []\n    for unit_idx in range(units):\n      if missing_probability > 0.0:\n        missing_points = int(num_points * missing_probability)\n      else:\n        missing_points = 0\n\n      x = ([default_input_value for _ in range(missing_points)] +\n           [i % num_buckets for i in range(num_points - missing_points)])\n      np.random.seed(unit_idx)\n      np.random.shuffle(x)\n      if data:\n        data = [values + (value,) for values, value in zip(data, x)]\n      else:\n        data = [(value,) for value in x]\n\n    return [np.asarray(v, dtype=np.int32) for v in data]\n\n  def _SetDefaults(self, config):\n    config.setdefault(\"units\", 1)\n    config.setdefault(\"use_multi_calibration_layer\", False)\n    config.setdefault(\"one_d_input\", False)\n    config.setdefault(\"output_min\", None)\n    config.setdefault(\"output_max\", None)\n    config.setdefault(\"default_input_value\", None)\n    config.setdefault(\"monotonicities\", None)\n    config.setdefault(\"missing_probability\", 0.0)\n    config.setdefault(\"constraint_assertion_eps\", 1e-6)\n    config.setdefault(\"kernel_regularizer\", None)\n    config.setdefault(\"model_dir\", \"/tmp/test_pwl_model_dir/\")\n    return config\n\n  def _TrainModel(self, config):\n    \"\"\"Trains model and returns loss.\n\n    Args:\n      config: Layer config internal for this test which specifies params of\n        piecewise linear layer to train.\n\n    Returns:\n      Training loss.\n    \"\"\"\n    logging.info(\"Testing config:\")\n    
logging.info(config)\n    config = self._SetDefaults(config)\n\n    self._ResetAllBackends()\n\n    if config[\"default_input_value\"] is not None:\n      # default_input_value is mapped to the last bucket, hence x_generator\n      # needs to generate in [0, ..., num_buckets-2] range.\n      num_random_buckets = config[\"num_buckets\"] - 1\n    else:\n      num_random_buckets = config[\"num_buckets\"]\n\n    # The input to the model can either be single or multi dimensional.\n    input_units = 1 if config[\"one_d_input\"] else config[\"units\"]\n\n    training_inputs = config[\"x_generator\"](\n        units=input_units,\n        num_points=config[\"num_training_records\"],\n        num_buckets=num_random_buckets,\n        missing_probability=config[\"missing_probability\"],\n        default_input_value=config[\"default_input_value\"])\n    training_labels = [config[\"y_function\"](x) for x in training_inputs]\n\n    # Either create multiple CategoricalCalibration layers and combine using a\n    # ParallelCombination layer, or create a single CategoricalCalibration with\n    # multiple output dimensions.\n    if config[\"use_multi_calibration_layer\"]:\n      num_calibration_layers = config[\"units\"]\n      categorical_calibraion_units = 1\n    else:\n      num_calibration_layers = 1\n      categorical_calibraion_units = config[\"units\"]\n\n    model = keras.models.Sequential()\n    model.add(keras.layers.Input(shape=[input_units], dtype=tf.int32))\n    calibration_layers = []\n    for _ in range(num_calibration_layers):\n      calibration_layers.append(\n          categorical_calibraion.CategoricalCalibration(\n              units=categorical_calibraion_units,\n              kernel_initializer=\"constant\",\n              num_buckets=config[\"num_buckets\"],\n              output_min=config[\"output_min\"],\n              output_max=config[\"output_max\"],\n              monotonicities=config[\"monotonicities\"],\n              kernel_regularizer=config[\"kernel_regularizer\"],\n              default_input_value=config[\"default_input_value\"]))\n    if len(calibration_layers) == 1:\n      model.add(calibration_layers[0])\n    else:\n      model.add(parallel_combination.ParallelCombination(calibration_layers))\n    if config[\"units\"] > 1:\n      model.add(\n          keras.layers.Lambda(\n              lambda x: tf.reduce_mean(x, axis=1, keepdims=True)))\n    model.compile(\n        loss=keras.losses.mean_squared_error,\n        optimizer=config[\"optimizer\"](learning_rate=config[\"learning_rate\"]))\n\n    training_data = (training_inputs, training_labels)\n\n    loss = test_utils.run_training_loop(\n        config=config,\n        training_data=training_data,\n        keras_model=model,\n        input_dtype=np.int32)\n\n    assetion_ops = []\n    for calibration_layer in calibration_layers:\n      assetion_ops.extend(\n          calibration_layer.assert_constraints(\n              eps=config[\"constraint_assertion_eps\"]))\n    if not tf.executing_eagerly() and assetion_ops:\n      tf.compat.v1.keras.backend.get_session().run(assetion_ops)\n\n    return loss\n\n  @parameterized.parameters((np.mean,), (lambda x: -np.mean(x),))\n  def testUnconstrainedNoMissingValue(self, y_function):\n    if self._disable_all:\n      return\n    config = {\n        \"num_training_records\": 200,\n        \"num_training_epoch\": 500,\n        \"optimizer\": keras.optimizers.Adam,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": 
y_function,\n        \"num_buckets\": 10,\n        \"output_min\": None,\n        \"output_max\": None,\n        \"monotonicities\": None,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps)\n    config[\"units\"] = 3\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps)\n    config[\"one_d_input\"] = True\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps)\n\n  @parameterized.parameters((np.mean,), (lambda x: -np.mean(x),))\n  def testUnconstrainedWithMissingValue(self, y_function):\n    if self._disable_all:\n      return\n    config = {\n        \"num_training_records\": 200,\n        \"num_training_epoch\": 500,\n        \"optimizer\": keras.optimizers.Adam,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": y_function,\n        \"num_buckets\": 10,\n        \"output_min\": None,\n        \"output_max\": None,\n        \"monotonicities\": None,\n        \"default_input_value\": -1,\n        \"missing_probability\": 0.1,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps)\n    config[\"units\"] = 3\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps)\n    config[\"one_d_input\"] = True\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.0, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (0.0, 9.0, None, 0.0),\n      (1.0, 8.0, None, 0.2),\n      (1.0, 8.0, [(6, 5)], 0.25),\n      (1.0, 8.0, [(6, 5), (5, 4)], 0.4),\n      (1.0, 8.0, [(6, 5), (7, 5)], 0.4),\n      (1.0, 8.0, [(6, 5), (5, 4), (4, 3)], 0.7),\n      (1.0, 8.0, [(7, 6), (6, 5), (4, 3), (3, 2)], 0.6),\n      (1.0, 8.0, [(7, 6), (6, 5), (5, 4), (4, 3), (3, 2)], 1.95),\n  )\n  def testConstraints(self, output_min, output_max, monotonicities,\n                      expected_loss):\n    if self._disable_all:\n      return\n    config = {\n        \"num_training_records\": 1000,\n        \"num_training_epoch\": 1000,\n        \"optimizer\": keras.optimizers.Adam,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": np.mean,\n        \"num_buckets\": 10,\n        \"output_min\": output_min,\n        \"output_max\": output_max,\n        \"monotonicities\": monotonicities,\n    }\n\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n    # Same input with multiple calibration units, should give out the same loss.\n    config[\"one_d_input\"] = True\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n    # With independently sampled unit-dim inputs loss is caled by 1/units.\n    config[\"one_d_input\"] = False\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(\n        loss,\n        expected_loss / config[\"units\"],\n        delta=self._loss_eps * config[\"units\"])\n\n    # Using separate calibration layers should give out the same loss.\n    config[\"use_multi_calibration_layer\"] = True\n    loss_multi_calib = self._TrainModel(config)\n    self.assertAlmostEqual(loss, loss_multi_calib, delta=self._loss_diff_eps)\n\n  def testCircularMonotonicites(self):\n    if self._disable_all:\n      return\n    config = {\n        \"num_training_records\": 200,\n        \"num_training_epoch\": 500,\n        \"optimizer\": 
keras.optimizers.Adam,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": float,\n        \"num_buckets\": 5,\n        \"monotonicities\": [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)],\n    }\n\n    with self.assertRaises(ValueError):\n      self._TrainModel(config)\n\n  @parameterized.parameters(\n      # Standard Keras regularizer:\n      (\n          keras.regularizers.l1_l2(l1=0.01, l2=0.001),),\n      # Tuple of regularizers:\n      (\n          (keras.regularizers.l1_l2(\n              l1=0.01, l2=0.0), keras.regularizers.l1_l2(l1=0.0, l2=0.001)),),\n  )\n  def testRegularizers(self, regularizer):\n    if self._disable_all:\n      return\n    config = {\n        \"num_training_records\": 20,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.Adam,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": lambda _: 2.0,\n        \"kernel_regularizer\": regularizer,\n        \"num_buckets\": 3,\n        \"output_min\": 0.0,\n        \"output_max\": 4.0,\n    }\n    loss = self._TrainModel(config)\n    # This loss is pure regularization loss because initializer matches target\n    # function and there was 0 training epochs.\n    self.assertAlmostEqual(loss, 0.072, delta=self._loss_eps)\n\n  def testOutputShape(self):\n    if self._disable_all:\n      return\n\n    # Not Splitting\n    units = 10\n    input_shape, output_shape = (units,), (None, units)\n    input_a = keras.layers.Input(shape=input_shape)\n    cat_cal_0 = categorical_calibraion.CategoricalCalibration(\n        num_buckets=3, units=units)\n    output = cat_cal_0(input_a)\n    self.assertAllEqual(output_shape,\n                        cat_cal_0.compute_output_shape(input_a.shape))\n    self.assertAllEqual(output_shape, output.shape)\n\n    # Splitting\n    output_shape = [(None, 1)] * units\n    cat_cal_1 = categorical_calibraion.CategoricalCalibration(\n        num_buckets=3, units=units, split_outputs=True)\n    output = cat_cal_1(input_a)\n    self.assertAllEqual(output_shape,\n                        cat_cal_1.compute_output_shape(input_a.shape))\n    self.assertAllEqual(output_shape, [o.shape for o in output])\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
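  The test above compares two topologies it treats as equivalent: one `CategoricalCalibration` layer with `units > 1`, or `units` single-output calibrators wrapped in a `ParallelCombination` layer. A hedged sketch of both constructions, assuming a Keras 2 / `tf.keras` environment as in the test's own import guard; bucket counts and bounds are illustrative.

  ```python
  import tensorflow as tf
  from tensorflow_lattice.python import categorical_calibration_layer
  from tensorflow_lattice.python import parallel_combination_layer

  keras = tf.keras
  num_buckets, units = 10, 3

  # Option A: a single calibrator producing `units` outputs per example.
  multi_unit = keras.models.Sequential([
      keras.layers.Input(shape=[units], dtype=tf.int32),
      categorical_calibration_layer.CategoricalCalibration(
          num_buckets=num_buckets, units=units, kernel_initializer="constant",
          output_min=0.0, output_max=1.0),
  ])

  # Option B: `units` independent single-unit calibrators, one per input
  # column, combined by ParallelCombination (the topology exercised when the
  # test config sets use_multi_calibration_layer=True).
  calibrators = [
      categorical_calibration_layer.CategoricalCalibration(
          num_buckets=num_buckets, units=1, kernel_initializer="constant",
          output_min=0.0, output_max=1.0)
      for _ in range(units)
  ]
  combined = keras.models.Sequential([
      keras.layers.Input(shape=[units], dtype=tf.int32),
      parallel_combination_layer.ParallelCombination(calibrators),
  ])
  ```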
  {
    "path": "tensorflow_lattice/python/cdf_layer.py",
    "content": "# Copyright 2021 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Projection free Cumulative Distribution Function layer.\n\nKeras implementation of TensorFlow Lattice CDF layer. Layer takes single or\nmulti-dimensional input and transforms it using a set of step functions. The\nlayer is naturally monotonic and bounded to the range [0, 1].\n\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\nfrom . import utils\n\n\nclass CDF(keras.layers.Layer):\n  # pyformat: disable\n  \"\"\"Cumulative Distribution Function (CDF) layer.\n\n  Layer takes input of shape `(batch_size, input_dim)` or `(batch_size, 1)` and\n  transforms it using `input_dim` number of cumulative distribution functions,\n  which are naturally monotonic and bounded to the range [0, 1]. If multi\n  dimensional input is provided, each output will be for the corresponding\n  input, otherwise all CDF functions will act on the same input. All units share\n  the same layer configuration, but each has their separate set of trained\n  parameters. The smoothness of the cumulative distribution functions depends on\n  the number of keypoints (i.e. step functions), the activation, and input\n  scaling.\n\n  Input shape:\n  Single input should be a rank-2 tensor with shape: `(batch_size, input_dim)`\n  or `(batch_size, 1)`.\n\n  Output shape:\n  Rank-2 tensor with shape `(batch, input_dim / factor, units)` if\n  `reduction=='none'`. 
Otherwise a rank-2 tensor with shape\n  `(batch_size, units)`.\n\n  Attributes:\n    - All `__init__` arguments.\n    kernel: TF variable which stores weights of each cdf function.\n    input_scaling: A constant if `input_scaling_type` is `'fixed'`, and a TF\n      variable if set to `'learned'`.\n\n  Example:\n\n  ```python\n  cdf = tfl.layers.CDF(\n    num_keypoints=10,\n    units=10,\n    # You can specify the type of activation to use for the step functions.\n    activation=\"sigmoid\",\n    # You can specifyc the type of reduction to use across the input dimension.\n    reduction=\"mean\",\n    # The input scaling type determines whether or not to use a fixed value or\n    # to learn the value during training.\n    input_scaling_type=\"fixed\",\n    # You can make the layer less connected by increasing the pruning factor,\n    # which must be a divisor of both the input dimension and units.\n    sparsity_factor=1,\n  )\n  ```\n  \"\"\"\n\n  def __init__(self,\n               num_keypoints,\n               units=1,\n               activation=\"relu6\",\n               reduction=\"mean\",\n               input_scaling_init=None,\n               input_scaling_type=\"fixed\",\n               input_scaling_monotonicity=\"increasing\",\n               sparsity_factor=1,\n               kernel_initializer=\"random_uniform\",\n               **kwargs):\n    # pyformat: disable\n    \"\"\"Initializes an instance of `Lattice`.\n\n    Args:\n      num_keypoints: The number of keypoints (i.e. step functions) to use for\n        each of `units` CDF functions.\n      units: The output dimension of the layer.\n      activation: The activation function to use for the step functions. One of:\n        - `'relu6'`: The `tf.nn.relu6` function.\n        - `'sigmoid'`: The `tf.nn.sigmoid` function.\n      reduction: The reduction used for each of the `units` CDF functions to\n        combine the CDF function output for each input dimension. One of:\n        - `'mean'`: The `tf.reduce_mean` function.\n        - `'geometric_mean'`: The n'th root of the product of each of the n\n          input dimensions.\n        - `'none'`: No input reduction.\n      input_scaling_init: The value used to initialize the input scaling.\n        Defaults to `num_keypoints` if set to `None`.\n      input_scaling_type: The type of input scaling to use. One of:\n        - `'fixed'`: input scaling will be a constant with value\n          `input_scaling_init`. This will be the value used for all input\n          dimensions.\n        - `'learned_shared'`: input scaling will be a weight learned during\n          training initialized with value `input_scaling_init`. This will be the\n          value used for all input dimensions.\n        - `'learned_per_input'`: input scaling will be a weight learned during\n          training initialized with value `input_scaling_init`. A separate value\n          will be learned for each input dimension.\n      input_scaling_monotonicity: One of:\n        - `'increasing'` or `1`: input scaling will be constrained to be\n          non-negative such that the output of the layer is monotonic in each\n          dimension.\n        - `'none'` or `0`: input scaling will not be constrained and the output\n          of the layer will no be guaranteed to be monotonic.\n      sparsity_factor: The factor by which to prune the connectivity of the\n        layer. If set to `1` there will be no pruning and the layer will be\n        fully connected. 
If set to `>1` the layer will be partially connected\n        where the number of connections will be reduced by this factor. Must be\n        a divisor of both the `input_dim` and `units`.\n      kernel_initializer: None or one of:\n        - `'random_uniform'`: initializes parameters as uniform\n          random functions in the range [0, 1].\n        - Any Keras initializer object.\n      **kwargs: Any additional `keras.layers.Layer` arguments.\n    \"\"\"\n    # pyformat: enable\n    super(CDF, self).__init__(**kwargs)\n    self.num_keypoints = num_keypoints\n    self.units = units\n    self.activation = activation\n    self.reduction = reduction\n    if input_scaling_init is None:\n      self.input_scaling_init = float(num_keypoints)\n    else:\n      self.input_scaling_init = float(input_scaling_init)\n    self.input_scaling_type = input_scaling_type\n    self.input_scaling_monotonicity = utils.canonicalize_monotonicity(\n        input_scaling_monotonicity)\n    self.sparsity_factor = sparsity_factor\n\n    self.kernel_initializer = create_kernel_initializer(\n        kernel_initializer_id=kernel_initializer)\n\n  def build(self, input_shape):\n    \"\"\"Standard Keras build() method.\"\"\"\n    input_dim = int(input_shape[-1])\n    if input_dim % self.sparsity_factor != 0:\n      raise ValueError(\n          \"sparsity_factor ({}) must be a divisor of input_dim ({})\".format(\n              self.sparsity_factor, input_dim))\n    if self.units % self.sparsity_factor != 0:\n      raise ValueError(\n          \"sparsity_factor ({}) must be a divisor of units ({})\".format(\n              self.sparsity_factor, self.units))\n\n    # Each keypoint represents a step function defined by the activation\n    # function specified. For an activation like relu6, this represents the\n    # the hinge point.\n    self.kernel = self.add_weight(\n        \"kernel\",\n        initializer=self.kernel_initializer,\n        shape=[\n            1, input_dim, self.num_keypoints,\n            int(self.units // self.sparsity_factor)\n        ])\n\n    # Input scaling ultimately represents the slope of the step function used.\n    # If the type is \"learned_*\" then input scaling will be a variable weight\n    # that is constrained depending on the monotonicity specified.\n    if self.input_scaling_type == \"fixed\":\n      self.input_scaling = tf.constant(self.input_scaling_init)\n    elif self.input_scaling_type == \"learned_shared\":\n      self.input_scaling = self.add_weight(\n          \"input_scaling\",\n          initializer=keras.initializers.Constant(self.input_scaling_init),\n          constraint=keras.constraints.NonNeg()\n          if self.input_scaling_monotonicity else None,\n          shape=[1])\n    elif self.input_scaling_type == \"learned_per_input\":\n      self.input_scaling = self.add_weight(\n          \"input_scaling\",\n          initializer=keras.initializers.Constant(self.input_scaling_init),\n          constraint=keras.constraints.NonNeg()\n          if self.input_scaling_monotonicity else None,\n          shape=[1, input_dim, 1, 1])\n    else:\n      raise ValueError(\"Invalid input_scaling_type: {}\".format(\n          self.input_scaling_type))\n\n  def call(self, inputs):\n    \"\"\"Standard Keras call() method.\"\"\"\n    input_dim = int(inputs.shape[-1])\n    # We add new axes to enable broadcasting.\n    x = inputs[..., tf.newaxis, tf.newaxis]\n\n    # Shape: (batch, input_dim, 1, 1)\n    #    --> (batch, input_dim, num_keypoints, units / factor)\n    #    --> (batch, 
input_dim, units / factor)\n    if self.activation == \"relu6\":\n      cdfs = tf.reduce_mean(\n          tf.nn.relu6(self.input_scaling * (x - self.kernel)), axis=2) / 6\n    elif self.activation == \"sigmoid\":\n      cdfs = tf.reduce_mean(\n          tf.nn.sigmoid(self.input_scaling * (x - self.kernel)), axis=2)\n    else:\n      raise ValueError(\"Invalid activation: {}\".format(self.activation))\n\n    result = cdfs\n\n    if self.sparsity_factor != 1:\n      # Shape: (batch, input_dim, units / factor)\n      #    --> (batch, input_dim / factor, units)\n      result = tf.reshape(\n          result, [-1, int(input_dim // self.sparsity_factor), self.units])\n\n    # Shape: (batch, input_dim / factor, units)\n    #.   --> (batch, units)\n    if self.reduction == \"mean\":\n      result = tf.reduce_mean(result, axis=1)\n    elif self.reduction == \"geometric_mean\":\n      num_terms = input_dim // self.sparsity_factor\n      result = tf.math.exp(\n          tf.reduce_sum(tf.math.log(result + 1e-3), axis=1) / num_terms)\n      # we use the log form above so that we can add the epsilon term\n      # tf.pow(tf.reduce_prod(cdfs, axis=1), 1. / num_terms)\n    elif self.reduction != \"none\":\n      raise ValueError(\"Invalid reduction: {}\".format(self.reduction))\n\n    return result\n\n  def get_config(self):\n    \"\"\"Standard Keras get_config() method.\"\"\"\n    config = {\n        \"num_keypoints\":\n            self.num_keypoints,\n        \"units\":\n            self.units,\n        \"activation\":\n            self.activation,\n        \"reduction\":\n            self.reduction,\n        \"input_scaling_init\":\n            self.input_scaling_init,\n        \"input_scaling_type\":\n            self.input_scaling_type,\n        \"input_scaling_monotonicity\":\n            self.input_scaling_monotonicity,\n        \"sparsity_factor\":\n            self.sparsity_factor,\n        \"kernel_initializer\":\n            keras.initializers.serialize(\n                self.kernel_initializer, use_legacy_format=True),\n    }\n    config.update(super(CDF, self).get_config())\n    return config\n\n\ndef create_kernel_initializer(kernel_initializer_id):\n  \"\"\"Returns a kernel Keras initializer object from its id.\n\n  This function is used to convert the 'kernel_initializer' parameter in the\n  constructor of `tfl.layers.CDF` into the corresponding initializer object.\n\n  Args:\n    kernel_initializer_id: See the documentation of the 'kernel_initializer'\n      parameter in the constructor of `tfl.layers.CDF`.\n\n  Returns:\n    The Keras initializer object for the `tfl.layers.CDF` kernel variable.\n  \"\"\"\n  if kernel_initializer_id in [\"random_uniform\", \"RandomUniform\"]:\n    return keras.initializers.RandomUniform(0.0, 1.0)\n  else:\n    return keras.initializers.get(kernel_initializer_id)\n"
  },
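  A numeric sketch of the transform computed in `CDF.call()` for the `relu6` activation: each keypoint defines a step function `relu6(input_scaling * (x - keypoint)) / 6`, and averaging over keypoints yields a monotonic output bounded to [0, 1]. The keypoint locations, scaling value, and batch below are illustrative assumptions rather than trained weights.

  ```python
  import tensorflow as tf

  num_keypoints = 4
  input_scaling = 10.0  # plays the role of input_scaling_init

  # Keypoint locations for a single input dimension and a single unit,
  # shaped (1, input_dim, num_keypoints, units) as in CDF.build().
  kernel = tf.reshape(tf.constant([0.2, 0.4, 0.6, 0.8]), [1, 1, num_keypoints, 1])

  x = tf.constant([[0.1], [0.5], [0.9]])   # (batch, input_dim)
  x = x[..., tf.newaxis, tf.newaxis]       # (batch, input_dim, 1, 1)

  # relu6 step functions, averaged over keypoints and rescaled to [0, 1].
  # This is the per-dimension CDF value before the 'mean' / 'geometric_mean'
  # reduction across input dimensions.
  cdf = tf.reduce_mean(tf.nn.relu6(input_scaling * (x - kernel)), axis=2) / 6
  print(cdf.numpy())  # increases with x: ~0.0, ~0.167, ~0.625
  ```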
  {
    "path": "tensorflow_lattice/python/cdf_test.py",
    "content": "# Copyright 2021 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for cdf.\"\"\"\n\nimport math\n\nfrom absl import logging\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_lattice.python import cdf_layer\nfrom tensorflow_lattice.python import test_utils\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n\nclass CdfLayerTest(parameterized.TestCase, tf.test.TestCase):\n\n  def setUp(self):\n    super(CdfLayerTest, self).setUp()\n    self.disable_all = False\n    self.loss_eps = 0.001\n    self.small_eps = 1e-6\n    keras.utils.set_random_seed(42)\n\n  def _ResetAllBackends(self):\n    keras.backend.clear_session()\n    tf.compat.v1.reset_default_graph()\n\n  def _SetDefaults(self, config):\n    config.setdefault(\"input_dims\", 1)\n    config.setdefault(\"num_keypoints\", 10)\n    config.setdefault(\"units\", 1)\n    config.setdefault(\"activation\", \"relu6\")\n    config.setdefault(\"reduction\", \"mean\")\n    config.setdefault(\"input_scaling_init\", None)\n    config.setdefault(\"input_scaling_type\", \"fixed\")\n    config.setdefault(\"sparsity_factor\", 1)\n    config.setdefault(\"kernel_initializer\", \"random_uniform\")\n\n    return config\n\n  def _ScatterXUniformly(self, num_points, input_dims):\n    \"\"\"Deterministically generates num_point random points within CDF.\"\"\"\n    np.random.seed(42)\n    x = []\n    for _ in range(num_points):\n      point = [np.random.random() for _ in range(input_dims)]\n      x.append(np.asarray(point))\n    if input_dims == 1:\n      x.sort()\n    return x\n\n  def _ScatterXUniformlyExtendedRange(self, num_points, input_dims):\n    \"\"\"Extends every dimension by 1.0 on both sides and generates points.\"\"\"\n    np.random.seed(42)\n    x = []\n    for _ in range(num_points):\n      point = [np.random.random() * 2 for _ in range(input_dims)]\n      x.append(np.asarray(point))\n    if input_dims == 1:\n      x.sort()\n    return x\n\n  def _TwoDMeshGrid(self, num_points, input_dims):\n    \"\"\"Mesh grid for visualisation of 3-d surfaces via pyplot.\"\"\"\n    if input_dims != 2:\n      raise ValueError(\"2-d mesh grid is possible only for 2-d inputs. Input\"\n                       \" dimension given: %s\" % input_dims)\n    return test_utils.two_dim_mesh_grid(\n        num_points=num_points, x_min=0.0, y_min=0.0, x_max=1.0, y_max=1.0)\n\n  def _TwoDMeshGridExtendedRange(self, num_points, input_dims):\n    \"\"\"Mesh grid extended by 1.0 on every side.\"\"\"\n    if input_dims != 2:\n      raise ValueError(\n          \"2-d mesh grid is possible only for 2-d lattice. 
Lattice\")\n    return test_utils.two_dim_mesh_grid(\n        num_points=num_points, x_min=-1.0, y_min=-1.0, x_max=2.0, y_max=2.0)\n\n  def _Sin(self, x):\n    return math.sin(x[0])\n\n  def _SinPlusX(self, x):\n    return math.sin(x[0]) + x[0] / 3.0\n\n  def _SinPlusXNd(self, x):\n    return np.sum([math.sin(y) + y / 5.0 for y in x])\n\n  def _SinOfSum(self, x):\n    return math.sin(sum(x))\n\n  def _Square(self, x):\n    return x[0]**2\n\n  def _ScaledSum(self, x):\n    result = 0.0\n    for y in x:\n      result += y / len(x)\n    return result\n\n  def _GetTrainingInputsAndLabels(self, config):\n    \"\"\"Generates training inputs and labels.\n\n    Args:\n      config: Dictionary with config for this unit test.\n\n    Returns:\n      Tuple `(training_inputs, training_labels)` where\n        `training_inputs` and `training_labels` are data for training.\n    \"\"\"\n    raw_training_inputs = config[\"x_generator\"](\n        num_points=config[\"num_training_records\"],\n        input_dims=config[\"input_dims\"])\n\n    if isinstance(raw_training_inputs, tuple):\n      # This means that raw inputs are 2-d mesh grid. Convert them into list of\n      # 2-d points.\n      training_inputs = list(np.dstack(raw_training_inputs).reshape((-1, 2)))\n    else:\n      training_inputs = raw_training_inputs\n\n    training_labels = [config[\"y_function\"](x) for x in training_inputs]\n    return training_inputs, training_labels\n\n  def _TrainModel(self, config):\n    logging.info(\"Testing config:\")\n    logging.info(config)\n    config = self._SetDefaults(config)\n    self._ResetAllBackends()\n\n    training_inputs, training_labels = (\n        self._GetTrainingInputsAndLabels(config))\n\n    keras_layer = cdf_layer.CDF(\n        num_keypoints=config[\"num_keypoints\"],\n        units=config[\"units\"],\n        activation=config[\"activation\"],\n        reduction=config[\"reduction\"],\n        input_scaling_init=config[\"input_scaling_init\"],\n        input_scaling_type=config[\"input_scaling_type\"],\n        sparsity_factor=config[\"sparsity_factor\"],\n        kernel_initializer=config[\"kernel_initializer\"],\n        input_shape=(config[\"input_dims\"],),\n        dtype=tf.float32)\n    model = keras.models.Sequential()\n    model.add(keras_layer)\n\n    # When we have multi-unit output, we average across the output units for\n    # testing.\n    if config[\"units\"] > 1:\n      model.add(\n          keras.layers.Lambda(\n              lambda x: tf.reduce_mean(x, axis=-1, keepdims=True)))\n\n    optimizer = config[\"optimizer\"](learning_rate=config[\"learning_rate\"])\n    model.compile(loss=\"mse\", optimizer=optimizer)\n\n    training_data = (training_inputs, training_labels)\n    loss = test_utils.run_training_loop(\n        config=config, training_data=training_data, keras_model=model\n    )\n\n    if tf.executing_eagerly():\n      tf.print(\"final weights: \", keras_layer.kernel)\n\n    return loss\n\n  @parameterized.parameters(\n      (\"relu6\", \"mean\", \"fixed\", 0.002203),\n      (\"relu6\", \"mean\", \"learned_shared\", 0.002216),\n      (\"relu6\", \"mean\", \"learned_per_input\", 0.002216),\n      (\"relu6\", \"geometric_mean\", \"fixed\", 0.002176),\n      (\"relu6\", \"geometric_mean\", \"learned_shared\", 0.002191),\n      (\"relu6\", \"geometric_mean\", \"learned_per_input\", 0.002191),\n      (\"sigmoid\", \"mean\", \"fixed\", 0.002451),\n      (\"sigmoid\", \"mean\", \"learned_shared\", 0.002443),\n      (\"sigmoid\", \"mean\", \"learned_per_input\", 0.002443),\n    
  (\"sigmoid\", \"geometric_mean\", \"fixed\", 0.002419),\n      (\"sigmoid\", \"geometric_mean\", \"learned_shared\", 0.002411),\n      (\"sigmoid\", \"geometric_mean\", \"learned_per_input\", 0.002411),\n  )\n  def test1Dim(self, activation, reduction, input_scaling_type, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"input_dims\": 1,\n        \"activation\": activation,\n        \"reduction\": reduction,\n        \"input_scaling_type\": input_scaling_type,\n        \"num_training_records\": 128,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinPlusX,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n\n  @parameterized.parameters(\n      (\"relu6\", \"mean\", \"fixed\", 0.171249),\n      (\"relu6\", \"mean\", \"learned_shared\", 0.170965),\n      (\"relu6\", \"mean\", \"learned_per_input\", 0.171091),\n      (\"relu6\", \"geometric_mean\", \"fixed\", 0.172444),\n      (\"relu6\", \"geometric_mean\", \"learned_shared\", 0.172357),\n      (\"relu6\", \"geometric_mean\", \"learned_per_input\", 0.172390),\n      (\"sigmoid\", \"mean\", \"fixed\", 0.172810),\n      (\"sigmoid\", \"mean\", \"learned_shared\", 0.172517),\n      (\"sigmoid\", \"mean\", \"learned_per_input\", 0.172653),\n      (\"sigmoid\", \"geometric_mean\", \"fixed\", 0.174273),\n      (\"sigmoid\", \"geometric_mean\", \"learned_shared\", 0.174110),\n      (\"sigmoid\", \"geometric_mean\", \"learned_per_input\", 0.174110),\n  )\n  def test2Dim(self, activation, reduction, input_scaling_type, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"input_dims\": 2,\n        \"activation\": activation,\n        \"reduction\": reduction,\n        \"input_scaling_type\": input_scaling_type,\n        \"num_training_records\": 900,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._SinPlusXNd,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n\n  @parameterized.parameters(\n      (\"relu6\", \"mean\", \"fixed\", 0.000156),\n      (\"relu6\", \"mean\", \"learned_shared\", 0.000144),\n      (\"relu6\", \"mean\", \"learned_per_input\", 0.000154),\n      (\"relu6\", \"geometric_mean\", \"fixed\", 0.000988),\n      (\"relu6\", \"geometric_mean\", \"learned_shared\", 0.000942),\n      (\"relu6\", \"geometric_mean\", \"learned_per_input\", 0.000977),\n      (\"sigmoid\", \"mean\", \"fixed\", 0.000078),\n      (\"sigmoid\", \"mean\", \"learned_shared\", 0.000078),\n      (\"sigmoid\", \"mean\", \"learned_per_input\", 0.0),\n      (\"sigmoid\", \"geometric_mean\", \"fixed\", 0.000793),\n      (\"sigmoid\", \"geometric_mean\", \"learned_shared\", 0.000794),\n      (\"sigmoid\", \"geometric_mean\", \"learned_per_input\", 0.000793),\n  )\n  def test5DimScaledSum(self, activation, reduction, input_scaling_type,\n                        expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"input_dims\": 5,\n        \"activation\": activation,\n        \"reduction\": reduction,\n        \"input_scaling_type\": input_scaling_type,\n        \"num_training_records\": 200,\n        \"num_training_epoch\": 200,\n        
\"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._ScaledSum,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n\n  @parameterized.parameters(\n      (\"relu6\", \"mean\", \"fixed\", 0.213702),\n      (\"relu6\", \"mean\", \"learned_shared\", 0.213702),\n      (\"relu6\", \"mean\", \"learned_per_input\", 0.213702),\n      (\"relu6\", \"geometric_mean\", \"fixed\", 0.215817),\n      (\"relu6\", \"geometric_mean\", \"learned_shared\", 0.215806),\n      (\"relu6\", \"geometric_mean\", \"learned_per_input\", 0.215816),\n      (\"sigmoid\", \"mean\", \"fixed\", 0.205054),\n      (\"sigmoid\", \"mean\", \"learned_shared\", 0.204950),\n      (\"sigmoid\", \"mean\", \"learned_per_input\", 0.205030),\n      (\"sigmoid\", \"geometric_mean\", \"fixed\", 0.204511),\n      (\"sigmoid\", \"geometric_mean\", \"learned_shared\", 0.204406),\n      (\"sigmoid\", \"geometric_mean\", \"learned_per_input\", 0.204488),\n  )\n  def test5DimSinOfSum(self, activation, reduction, input_scaling_type,\n                       expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"input_dims\": 5,\n        \"activation\": activation,\n        \"reduction\": reduction,\n        \"input_scaling_type\": input_scaling_type,\n        \"num_training_records\": 200,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinOfSum,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n\n  @parameterized.parameters(\n      (\"relu6\", \"mean\", \"fixed\", 0.000424),\n      (\"relu6\", \"mean\", \"learned_shared\", 0.000424),\n      (\"relu6\", \"mean\", \"learned_per_input\", 0.000424),\n      (\"relu6\", \"geometric_mean\", \"fixed\", 0.000439),\n      (\"relu6\", \"geometric_mean\", \"learned_shared\", 0.000439),\n      (\"relu6\", \"geometric_mean\", \"learned_per_input\", 0.000439),\n      (\"sigmoid\", \"mean\", \"fixed\", 0.000444),\n      (\"sigmoid\", \"mean\", \"learned_shared\", 0.000444),\n      (\"sigmoid\", \"mean\", \"learned_per_input\", 0.000444),\n      (\"sigmoid\", \"geometric_mean\", \"fixed\", 0.000459),\n      (\"sigmoid\", \"geometric_mean\", \"learned_shared\", 0.000459),\n      (\"sigmoid\", \"geometric_mean\", \"learned_per_input\", 0.000459),\n  )\n  def test1DimInputOutOfBounds(self, activation, reduction, input_scaling_type,\n                               expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"input_dims\": 1,\n        \"activation\": activation,\n        \"reduction\": reduction,\n        \"input_scaling_type\": input_scaling_type,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformlyExtendedRange,\n        \"y_function\": self._Sin,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n\n  @parameterized.parameters(\n      (\"relu6\", \"mean\", \"fixed\", 0.339018),\n      (\"relu6\", \"mean\", \"learned_shared\", 0.338988),\n      (\"relu6\", \"mean\", \"learned_per_input\", 0.339002),\n      
(\"relu6\", \"geometric_mean\", \"fixed\", 0.370072),\n      (\"relu6\", \"geometric_mean\", \"learned_shared\", 0.370105),\n      (\"relu6\", \"geometric_mean\", \"learned_per_input\", 0.370144),\n      (\"sigmoid\", \"mean\", \"fixed\", 0.340095),\n      (\"sigmoid\", \"mean\", \"learned_shared\", 0.340094),\n      (\"sigmoid\", \"mean\", \"learned_per_input\", 0.340094),\n      (\"sigmoid\", \"geometric_mean\", \"fixed\", 0.368851),\n      (\"sigmoid\", \"geometric_mean\", \"learned_shared\", 0.368849),\n      (\"sigmoid\", \"geometric_mean\", \"learned_per_input\", 0.368850),\n  )\n  def test2DimInputOutOfBounds(self, activation, reduction, input_scaling_type,\n                               expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"input_dims\": 2,\n        \"activation\": activation,\n        \"reduction\": reduction,\n        \"input_scaling_type\": input_scaling_type,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGridExtendedRange,\n        \"y_function\": self._SinOfSum,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n\n  @parameterized.parameters(\n      (6, 6, \"relu6\", \"mean\", \"fixed\", 3, 0.070477),\n      (8, 8, \"relu6\", \"mean\", \"learned_shared\", 4, 0.076625),\n      (8, 8, \"relu6\", \"mean\", \"learned_per_input\", 4, 0.076696),\n      (3, 3, \"relu6\", \"geometric_mean\", \"fixed\", 3, 0.031802),\n      (4, 4, \"relu6\", \"geometric_mean\", \"learned_shared\", 2, 0.049083),\n      (5, 5, \"relu6\", \"geometric_mean\", \"learned_per_input\", 2.5, 0.059841),\n      (6, 6, \"sigmoid\", \"mean\", \"fixed\", 3, 0.075446),\n      (8, 8, \"sigmoid\", \"mean\", \"learned_shared\", 4, 0.087095),\n      (8, 8, \"sigmoid\", \"mean\", \"learned_per_input\", 4, 0.087091),\n      (3, 3, \"sigmoid\", \"geometric_mean\", \"fixed\", 3, 0.033214),\n      (4, 4, \"sigmoid\", \"geometric_mean\", \"learned_shared\", 2, 0.044370),\n      (5, 5, \"sigmoid\", \"geometric_mean\", \"learned_per_input\", 2.5, 0.056680),\n  )\n  def testMultiUnitOutputSparsity(self, input_dims, units, activation,\n                                  reduction, input_scaling_type,\n                                  sparsity_factor, expected_loss):\n    if self.disable_all:\n      return\n    # Set the random seed for the initializer for consistent results.\n    kernel_initializer = keras.initializers.RandomUniform(0.0, 1.0, seed=42)\n    config = {\n        \"input_dims\": input_dims,\n        \"units\": units,\n        \"activation\": activation,\n        \"reduction\": reduction,\n        \"input_scaling_type\": input_scaling_type,\n        \"sparsity_factor\": sparsity_factor,\n        \"kernel_initializer\": kernel_initializer,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._Square,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n\n  @parameterized.parameters(\n      (\"relu6\", \"mean\", 4, \"fixed\", 0.181436),\n      (\"relu6\", \"mean\", 6, \"learned_shared\", 0.173429),\n      (\"relu6\", \"mean\", 6, \"learned_per_input\", 
0.174332),\n      (\"relu6\", \"geometric_mean\", 8, \"fixed\", 0.173544),\n      (\"relu6\", \"geometric_mean\", 15, \"learned_shared\", 0.172116),\n      (\"relu6\", \"geometric_mean\", 15, \"learned_per_input\", 0.172146),\n      (\"sigmoid\", \"mean\", 4, \"fixed\", 0.194161),\n      (\"sigmoid\", \"mean\", 6, \"learned_shared\", 0.177846),\n      (\"sigmoid\", \"mean\", 6, \"learned_per_input\", 0.179537),\n      (\"sigmoid\", \"geometric_mean\", 8, \"fixed\", 0.176535),\n      (\"sigmoid\", \"geometric_mean\", 15, \"learned_shared\", 0.172762),\n      (\"sigmoid\", \"geometric_mean\", 15, \"learned_per_input\", 0.172728),\n  )\n  def testInputScalingInit(self, activation, reduction, input_scaling_init,\n                           input_scaling_type, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"input_dims\": 2,\n        \"activation\": activation,\n        \"reduction\": reduction,\n        \"input_scaling_init\": input_scaling_init,\n        \"input_scaling_type\": input_scaling_type,\n        \"num_training_records\": 900,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._SinPlusXNd,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n\n  @parameterized.parameters(\n      (2, 10, 5, \"relu6\", \"mean\", \"fixed\", 30),\n      (2, 10, 5, \"relu6\", \"mean\", \"learned_shared\", 35),\n      (2, 10, 5, \"relu6\", \"mean\", \"learned_per_input\", 35),\n      (2, 10, 5, \"relu6\", \"geometric_mean\", \"fixed\", 36),\n      (2, 10, 5, \"relu6\", \"geometric_mean\", \"learned_shared\", 41),\n      (2, 10, 5, \"relu6\", \"geometric_mean\", \"learned_per_input\", 41),\n      (4, 20, 10, \"relu6\", \"mean\", \"fixed\", 30),\n      (4, 20, 10, \"relu6\", \"mean\", \"learned_shared\", 35),\n      (4, 20, 10, \"relu6\", \"mean\", \"learned_per_input\", 35),\n      (4, 20, 10, \"relu6\", \"geometric_mean\", \"fixed\", 36),\n      (4, 20, 10, \"relu6\", \"geometric_mean\", \"learned_shared\", 41),\n      (4, 20, 10, \"relu6\", \"geometric_mean\", \"learned_per_input\", 41),\n      (2, 10, 5, \"sigmoid\", \"mean\", \"fixed\", 28),\n      (2, 10, 5, \"sigmoid\", \"mean\", \"learned_shared\", 33),\n      (2, 10, 5, \"sigmoid\", \"mean\", \"learned_per_input\", 33),\n      (2, 10, 5, \"sigmoid\", \"geometric_mean\", \"fixed\", 34),\n      (2, 10, 5, \"sigmoid\", \"geometric_mean\", \"learned_shared\", 39),\n      (2, 10, 5, \"sigmoid\", \"geometric_mean\", \"learned_per_input\", 39),\n      (4, 20, 10, \"sigmoid\", \"mean\", \"fixed\", 28),\n      (4, 20, 10, \"sigmoid\", \"mean\", \"learned_shared\", 33),\n      (4, 20, 10, \"sigmoid\", \"mean\", \"learned_per_input\", 33),\n      (4, 20, 10, \"sigmoid\", \"geometric_mean\", \"fixed\", 34),\n      (4, 20, 10, \"sigmoid\", \"geometric_mean\", \"learned_shared\", 39),\n      (4, 20, 10, \"sigmoid\", \"geometric_mean\", \"learned_per_input\", 39),\n  )\n  def testGraphSize(self, input_dims, num_keypoints, units, activation,\n                    reduction, input_scaling_type, expected_graph_size):\n    # If this test failed then you modified core lattice interpolation logic in\n    # a way which increases number of ops in the graph. Or maybe Keras team\n    # changed something under the hood. 
Please ensure that this increase is\n    # unavoidable and try to minimize it.\n    if self.disable_all:\n      return\n    tf.compat.v1.disable_eager_execution()\n    tf.compat.v1.reset_default_graph()\n\n    layer = cdf_layer.CDF(\n        num_keypoints=num_keypoints,\n        units=units,\n        activation=activation,\n        reduction=reduction,\n        input_scaling_type=input_scaling_type)\n\n    input_tensor = tf.ones(shape=(1, input_dims))\n    layer(input_tensor)\n    graph_size = len(tf.compat.v1.get_default_graph().as_graph_def().node)\n\n    self.assertLessEqual(graph_size, expected_graph_size)\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_lattice/python/conditional_cdf.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Implements CDF transformation with derived parameters (kernels).\n\n`cdf_fn` is similar to `tfl.layers.CDF`, which is an additive / multiplicative\naverage of a few shifted and scaled `sigmoid` or `relu6` basis functions,\nwith the difference that the functions are parametrized by the provided\nparameters instead of learnable weights belonging to a `tfl.layers.CDF` layer.\n\nThese parameters can be one of:\n\n  - constants,\n  - trainable variables,\n  - outputs from other TF modules.\n\nFor inputs of shape `(batch_size, input_dim)`, two sets of free-form\nparameters are used to configure the CDF function:\n\n- `location_parameters` for where to place the sigmoid / relu6 transformation\nbasis,\n- `scaling_parameters` (optional) for the horizontal scaling before applying\nthe transformation basis.\n\"\"\"\n\nfrom typing import Optional, Union, Tuple\nimport tensorflow as tf\n\n\ndef _verify_cdf_params(\n    inputs: tf.Tensor,\n    location_parameters: tf.Tensor,\n    scaling_parameters: Optional[tf.Tensor],\n    units: int,\n    activation: str,\n    reduction: str,\n    sparsity_factor: int,\n) -> None:\n  \"\"\"Verifies the arguments of cdf_fn call.\n\n  Args:\n    inputs: inputs to the CDF function.\n    location_parameters: parameters for deciding the locations of the\n      transformations.\n    scaling_parameters: parameters for deciding the horizontal scaling of the\n      transformations.\n    units: output dimension.\n    activation: either `sigmoid` or `relu6` for selecting the transformation.\n    reduction: either `mean`, `geometric_mean`, or `none` to specify whether to\n      perform averaging and which average to perform.\n    sparsity_factor: deciding the level of sparsity during reduction.\n      `input_dim` and `units` should both be divisible by `sparsity_factor`.\n  \"\"\"\n  if activation not in (\"sigmoid\", \"relu6\"):\n    raise ValueError(\n        f\"activation = {activation} is not supported. Use 'sigmoid' or 'relu6'.\"\n    )\n  if reduction not in (\"mean\", \"geometric_mean\", \"none\"):\n    raise ValueError(\n        f\"reduction = {reduction} is not supported. 
Use 'mean',\"\n        \" 'geometric_mean' or 'none'.\"\n    )\n\n  if len(inputs.shape) != 2:\n    raise ValueError(\n        f\"inputs shape {inputs.shape} is not (batch_size, input_dim).\"\n    )\n\n  input_dim = inputs.shape[1]\n  if units % sparsity_factor != 0:\n    raise ValueError(\n        f\"units = {units} is not divisible by sparsity_factor =\"\n        f\" {sparsity_factor}.\"\n    )\n  if input_dim % sparsity_factor != 0:\n    raise ValueError(\n        f\"input_dim = {input_dim} is not divisible by sparsity_factor =\"\n        f\" {sparsity_factor}.\"\n    )\n\n  if (\n      len(location_parameters.shape) != 4\n      or location_parameters.shape[1] != input_dim\n      or location_parameters.shape[3] != units // sparsity_factor\n  ):\n    raise ValueError(\n        \"location_parameters shape\"\n        f\" {location_parameters.shape} is not (batch, input_dim, \"\n        f\"num_functions, units / sparsity_factor = {units // sparsity_factor}).\"\n    )\n\n  if scaling_parameters is not None:\n    try:\n      _ = tf.broadcast_to(\n          scaling_parameters,\n          location_parameters.shape,\n          name=\"cdf_fn_try_broadcasting\",\n      )\n    except Exception as err:\n      raise ValueError(\n          \"scaling_parameters and location_parameters likely\"\n          \" are not broadcastable. Shapes of scaling_parameters:\"\n          f\" {scaling_parameters.shape}, location_parameters:\"\n          f\" {location_parameters.shape}.\"\n      ) from err\n\n\n@tf.function\ndef cdf_fn(\n    inputs: tf.Tensor,\n    location_parameters: tf.Tensor,\n    scaling_parameters: Optional[tf.Tensor] = None,\n    units: int = 1,\n    activation: str = \"relu6\",\n    reduction: str = \"mean\",\n    sparsity_factor: int = 1,\n    scaling_exp_transform_multiplier: Optional[float] = None,\n    return_derived_parameters: bool = False,\n) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor, tf.Tensor]]:\n  r\"\"\"Maps `inputs` through a CDF function specified by keypoint parameters.\n\n  `cdf_fn` is similar to `tfl.layers.CDF`, which is an additive / multiplicative\n  average of a few shifted and scaled `sigmoid` or `relu6` basis functions,\n  with the difference that the functions are parametrized by the provided\n  parameters instead of learnable weights belonging to a `tfl.layers.CDF` layer.\n\n  These parameters can be one of:\n\n    - constants,\n    - trainable variables,\n    - outputs from other TF modules.\n\n  For inputs of shape `(batch_size, input_dim)`, two sets of free-form\n  parameters are used to configure the CDF function:\n\n  - `location_parameters` for where to place the sigmoid / relu6 transformation\n  basis,\n  - `scaling_parameters` (optional) for the horizontal scaling before applying\n  the transformation basis.\n\n  The transformation per dimension is `x -> activation(scale * (x - location))`,\n  where:\n\n  - `scale` (specified via `scaling_parameter`) is the input scaling for each\n  dimension and needs to be strictly positive for the CDF function to become\n  monotonic. If needed, you can set `scaling_exp_transform_multiplier` to get\n  `scale = exp(scaling_parameter * scaling_exp_transform_multiplier)`, which\n  guarantees strict positivity.\n  - `location` (specified via `location_parameter`) is the input shift. Notice\n  for `relu6` this is where the transformation starts to be nonzero, whereas for\n  `sigmoid` this is where the transformation hits 0.5.\n  - `activation` is either `sigmoid` or `relu6` (for `relu6 / 6`).\n
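\n  For illustration (with made-up numbers): if the activation is `sigmoid`,\n  `location = 0.0` and `scale = 2.0`, then an input `x = 0.5` is mapped to\n  `sigmoid(2.0 * (0.5 - 0.0)) = sigmoid(1.0)`, which is roughly 0.731.\n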
\n  An optional `reduction` operation will compute the additive / multiplicative\n  average for the input dims after their individual CDF transformation. `mean`\n  and `geometric_mean` are supported if specified.\n\n  `sparsity_factor` decides the level of sparsity during reduction. For\n  instance, the default `sparsity_factor = 1` calculates the average of *all*\n  input dims, whereas `sparsity_factor = 2` calculates the average of *every\n  other* input dim, and so on.\n\n  Input shape:\n    We denote `num_functions` as the number of `sigmoid` or `relu6 / 6` basis\n    functions used for each CDF transformation.\n\n    `inputs` should be:\n\n    - `(batch_size, input_dim)`.\n\n    `location_parameters` should be:\n\n    - `(batch_size, input_dim, num_functions, units // sparsity_factor)`.\n\n    `scaling_parameters` when provided should be broadcast friendly\n    with `location_parameters`, e.g. one of\n\n    - `(batch_size, input_dim, 1, 1)`,\n    - `(batch_size, input_dim, num_functions, 1)`,\n    - `(batch_size, input_dim, 1, units // sparsity_factor)`,\n    - `(batch_size, input_dim, num_functions, units // sparsity_factor)`.\n\n  Args:\n    inputs: inputs to the CDF function.\n    location_parameters: parameters for deciding the locations of the\n      transformations.\n    scaling_parameters: parameters for deciding the horizontal scaling of the\n      transformations.\n    units: output dimension.\n    activation: either `sigmoid` or `relu6` for selecting the transformation.\n    reduction: either `mean`, `geometric_mean`, or `none` to specify whether to\n      perform averaging and which average to perform.\n    sparsity_factor: deciding the level of sparsity during reduction.\n      `input_dim` and `units` should both be divisible by `sparsity_factor`.\n    scaling_exp_transform_multiplier: if provided, will be used inside an\n      exponential transformation for `scaling_parameters`. This can be useful if\n      `scaling_parameters` is free-form.\n    return_derived_parameters: Whether `location_parameters` and\n      `scaling_parameters` should be output along with the model output (e.g.\n      for loss function computation purposes).\n\n  Returns:\n    If `return_derived_parameters = False`:\n\n      - The CDF transformed outputs as a tensor with shape either\n        `(batch_size, units)` if `reduction = 'mean' / 'geometric_mean'`, or\n        `(batch_size, input_dim // sparsity_factor, units)` if\n        `reduction = 'none'`.\n\n    If `return_derived_parameters = True`:\n\n      - A tuple of three elements:\n\n        1. The CDF transformed outputs.\n        2. `location_parameters`.\n        3. 
`scaling_parameters`, with `exp` transformation applied if specified.\n  \"\"\"\n\n  _verify_cdf_params(\n      inputs,\n      location_parameters,\n      scaling_parameters,\n      units,\n      activation,\n      reduction,\n      sparsity_factor,\n  )\n  input_dim = inputs.shape[1]\n  x = inputs[..., tf.newaxis, tf.newaxis] - location_parameters\n  if scaling_parameters is not None:\n    if scaling_exp_transform_multiplier is not None:\n      scaling_parameters = tf.math.exp(\n          scaling_parameters * scaling_exp_transform_multiplier\n      )\n    x *= scaling_parameters\n  else:\n    # For use when return_derived_parameters = True.\n    scaling_parameters = tf.ones_like(location_parameters, dtype=tf.float32)\n\n  # Shape: (batch, input_dim, 1, 1)\n  #    --> (batch, input_dim, num_functions, units / factor)\n  #    --> (batch, input_dim, units / factor).\n  if activation == \"relu6\":\n    result = tf.reduce_mean(tf.nn.relu6(x), axis=2) / 6\n  else:  # activation == \"sigmoid\":\n    result = tf.reduce_mean(tf.nn.sigmoid(x), axis=2)\n\n  if sparsity_factor != 1:\n    # Shape: (batch, input_dim, units / factor)\n    #    --> (batch, input_dim / factor, units).\n    result = tf.reshape(result, (-1, input_dim // sparsity_factor, units))\n\n  # Shape: (batch, input_dim / factor, units) --> (batch, units).\n  if reduction == \"mean\":\n    result = tf.reduce_mean(result, axis=1)\n  elif reduction == \"geometric_mean\":\n    # We use the log form so that we can add the epsilon term\n    # tf.pow(tf.reduce_prod(cdfs, axis=1), 1. / num_terms).\n    result = tf.math.exp(tf.reduce_mean(tf.math.log(result + 1e-8), axis=1))\n  # Otherwise reduction == \"none\".\n\n  if return_derived_parameters:\n    return (result, location_parameters, scaling_parameters)\n  else:\n    return result\n"
  },
  {
    "path": "tensorflow_lattice/python/conditional_cdf_test.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TF tests for conditional_cdf.py.\"\"\"\n\nfrom absl.testing import parameterized\nimport tensorflow as tf\nfrom tensorflow_lattice.python.conditional_cdf import cdf_fn\n\n_EPSILON = 1e-4\n\n\nclass CdfFnTest(parameterized.TestCase, tf.test.TestCase):\n\n  def assertAllClose(self, x, y):\n    super().assertAllClose(x, y, atol=1e-4)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"trivial\",\n          inputs=[[-1.0], [0.0], [1.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n          ],\n          scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]],\n          reduction=\"none\",\n          expected=[[[0.29604811]], [[0.5]], [[0.70395189]]],\n      ),\n      dict(\n          testcase_name=\"trivial_mean\",\n          inputs=[[-1.0], [0.0], [1.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n          ],\n          scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]],\n          reduction=\"mean\",\n          expected=[[0.29604811], [0.5], [0.70395189]],\n      ),\n      dict(\n          testcase_name=\"moderate\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[1.0]], [[1.0]]],\n              [[[1.0]], [[1.0]]],\n          ],\n          reduction=\"none\",\n          expected=[\n              [[0.29604811], [0.5]],\n              [[0.5], [0.61075843]],\n              [[0.70395189], [0.66584245]],\n          ],\n      ),\n      dict(\n          testcase_name=\"moderate_scaling\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[2.0]], [[2.0]]],\n              [[[5.0]], [[7.0]]],\n          ],\n          reduction=\"none\",\n          expected=[\n              [[0.29604811], [0.5]],\n              [[0.5], [0.632815979]],\n              [[0.8310872504], [0.6666666666]],\n          ],\n      ),\n      dict(\n          testcase_name=\"moderate_mean\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              
[[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=None,\n          reduction=\"mean\",\n          expected=[[0.398024055], [0.555379215], [0.684897170]],\n      ),\n      dict(\n          testcase_name=\"moderate_geometric_mean\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[1.0]], [[1.0]]],\n              [[[1.0]], [[1.0]]],\n          ],\n          reduction=\"geometric_mean\",\n          expected=[[0.38473894], [0.55261127], [0.68463206]],\n      ),\n  )\n  def test_compute_sigmoid(\n      self,\n      inputs,\n      location_parameters,\n      scaling_parameters,\n      reduction,\n      expected,\n  ):\n    result = cdf_fn(\n        inputs=tf.constant(inputs, dtype=tf.float32),\n        location_parameters=tf.constant(location_parameters, dtype=tf.float32),\n        scaling_parameters=(\n            tf.constant(scaling_parameters, dtype=tf.float32)\n            if scaling_parameters is not None\n            else None\n        ),\n        units=1,\n        activation=\"sigmoid\",\n        reduction=reduction,\n    )\n    self.assertAllClose(result, expected)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"trivial\",\n          inputs=[[-1.0], [0.0], [1.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n          ],\n          scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]],\n          reduction=\"none\",\n          expected=[[[0.0]], [[1.0 / 18]], [[3.0 / 18]]],\n      ),\n      dict(\n          testcase_name=\"trivial_none_scaling\",\n          inputs=[[-1.0], [0.0], [1.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n          ],\n          scaling_parameters=None,\n          reduction=\"none\",\n          expected=[[[0.0]], [[1.0 / 18]], [[3.0 / 18]]],\n      ),\n      dict(\n          testcase_name=\"trivial_mean\",\n          inputs=[[-1.0], [0.0], [1.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n          ],\n          scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]],\n          reduction=\"mean\",\n          expected=[[0.0], [1.0 / 18], [3.0 / 18]],\n      ),\n      dict(\n          testcase_name=\"moderate\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[1.0]], [[1.0]]],\n              [[[1.0]], [[1.0]]],\n          ],\n          reduction=\"none\",\n          expected=[\n              [[0.0], [2.0 / 18]],\n              [[1.0 / 18], [5.0 / 18]],\n              [[3.0 / 18], [8.0 / 18]],\n          ],\n      ),\n      dict(\n          testcase_name=\"moderate_none_scaling\",\n          inputs=[[-1.0, 
0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=None,\n          reduction=\"none\",\n          expected=[\n              [[0.0], [2.0 / 18]],\n              [[1.0 / 18], [5.0 / 18]],\n              [[3.0 / 18], [8.0 / 18]],\n          ],\n      ),\n      dict(\n          testcase_name=\"moderate_scaling\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[2.0]], [[2.0]]],\n              [[[5.0]], [[0.5]]],\n          ],\n          reduction=\"none\",\n          expected=[\n              [[0.0], [2.0 / 18]],\n              [[2.0 / 18], [8.0 / 18]],\n              [[11.0 / 18], [4.0 / 18]],\n          ],\n      ),\n      dict(\n          testcase_name=\"moderate_mean\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[1.0]], [[1.0]]],\n              [[[1.0]], [[1.0]]],\n          ],\n          reduction=\"mean\",\n          expected=[[1.0 / 18], [3.0 / 18], [5.5 / 18]],\n      ),\n      dict(\n          testcase_name=\"moderate_geometric_mean\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[1.0]], [[1.0]]],\n              [[[1.0]], [[1.0]]],\n          ],\n          reduction=\"geometric_mean\",\n          expected=[[0.0], [2.23606797 / 18], [4.898979485 / 18]],\n      ),\n  )\n  def test_compute_relu6(\n      self,\n      inputs,\n      location_parameters,\n      scaling_parameters,\n      reduction,\n      expected,\n  ):\n    result = cdf_fn(\n        inputs=tf.constant(inputs, dtype=tf.float32),\n        location_parameters=tf.constant(location_parameters, dtype=tf.float32),\n        scaling_parameters=(\n            tf.constant(scaling_parameters, dtype=tf.float32)\n            if scaling_parameters is not None\n            else None\n        ),\n        units=1,\n        activation=\"relu6\",\n        reduction=reduction,\n    )\n    self.assertAllClose(result, expected)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"0.0\",\n          scaling_exp_transform_multiplier=0.0,\n          expected=[[0.398024055], [0.555379215], [0.684897170]],\n      ),\n      dict(\n          testcase_name=\"1.0\",\n          scaling_exp_transform_multiplier=1.0,\n          expected=[[0.344373118], [0.58323046], [0.6278357037]],\n      ),\n      dict(\n          testcase_name=\"-1.0\",\n          scaling_exp_transform_multiplier=-1.0,\n  
        expected=[[0.4554976295], [0.51644151635], [0.66798191003]],\n      ),\n  )\n  def test_scaling_exp_transformation(\n      self, scaling_exp_transform_multiplier, expected\n  ):\n    result = cdf_fn(\n        inputs=tf.constant([[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]]),\n        location_parameters=tf.constant([\n            [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n            [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n            [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n        ]),\n        scaling_parameters=tf.constant([\n            [[[1.0]], [[1.0]]],\n            [[[0.0]], [[2.0]]],\n            [[[-1.0]], [[3.0]]],\n        ]),\n        reduction=\"mean\",\n        activation=\"sigmoid\",\n        scaling_exp_transform_multiplier=scaling_exp_transform_multiplier,\n    )\n    self.assertAllClose(result, expected)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"sigmoid_repeat\",\n          inputs=[[0.0], [0.0], [0.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n          ],\n          scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]],\n          units=1,\n          activation=\"sigmoid\",\n          sparsity_factor=1,\n          scaling_exp_transform_multiplier=None,\n          expected=[\n              [\n                  [[[-0.06553731], [-0.08333334], [-0.06553732]]],\n                  [[[-0.06553731], [-0.08333334], [-0.06553732]]],\n                  [[[-0.06553731], [-0.08333334], [-0.06553732]]],\n              ],\n              [\n                  [[[-7.4505806e-09]]],\n                  [[[-7.4505806e-09]]],\n                  [[[-7.4505806e-09]]],\n              ],\n          ],\n      ),\n      dict(\n          testcase_name=\"sigmoid_trivial\",\n          inputs=[[-1.0], [0.0], [1.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n          ],\n          scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]],\n          units=1,\n          activation=\"sigmoid\",\n          sparsity_factor=1,\n          scaling_exp_transform_multiplier=None,\n          expected=[\n              [\n                  [[[-0.04934135], [-0.03880439], [-0.0207221]]],\n                  [[[-0.06553731], [-0.08333334], [-0.06553732]]],\n                  [[[-0.04927362], [-0.09227023], [-0.11732531]]],\n              ],\n              [[[[-8.0248594e-02]]], [[[-7.4505806e-09]]], [[[1.9081746e-01]]]],\n          ],\n      ),\n      dict(\n          testcase_name=\"relu6\",\n          inputs=[[-1.0], [0.0], [1.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n          ],\n          scaling_parameters=[[[[1.0]]], [[[1.0]]], [[[1.0]]]],\n          units=1,\n          activation=\"relu6\",\n          sparsity_factor=1,\n          scaling_exp_transform_multiplier=None,\n          expected=[\n              [\n                  [[[-0.0], [-0.0], [-0.0]]],\n                  [[[-0.00617284], [-0.0], [-0.0]]],\n                  [[[-0.01851852], [-0.01851852], [-0.0]]],\n              ],\n              [[[[0.0]]], [[[0.00617284]]], [[[0.05555556]]]],\n          ],\n      ),\n      dict(\n          testcase_name=\"units_multiplier_sigmoid\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          
location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[2.0]], [[2.0]]],\n              [[[5.0]], [[7.0]]],\n          ],\n          units=2,\n          activation=\"sigmoid\",\n          sparsity_factor=2,\n          scaling_exp_transform_multiplier=0.0,\n          expected=[\n              [\n                  [\n                      [[-0.04934135], [-0.03880439], [-0.0207221]],\n                      [[-0.03499786], [-0.08333334], [-0.03499787]],\n                  ],\n                  [\n                      [[-0.06553731], [-0.08333334], [-0.06553732]],\n                      [[-0.00719178], [-0.08005493], [-0.04275048]],\n                  ],\n                  [\n                      [[-0.04927362], [-0.09227023], [-0.11732531]],\n                      [[-0.00109488], [-0.04660612], [-0.04660612]],\n                  ],\n              ],\n              [[[[-0.0]], [[-0.0]]], [[[-0.0]], [[0.0]]], [[[0.0]], [[0.0]]]],\n          ],\n      ),\n      dict(\n          testcase_name=\"units_multiplier_relu6\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[2.0]], [[2.0]]],\n              [[[5.0]], [[7.0]]],\n          ],\n          units=2,\n          activation=\"relu6\",\n          sparsity_factor=2,\n          scaling_exp_transform_multiplier=0.01,\n          expected=[\n              [\n                  [\n                      [[-0.000], [-0.000], [-0.000]],\n                      [[-0.01259508], [-0.0], [-0.0]],\n                  ],\n                  [\n                      [[-0.00642476], [-0.0], [-0.0]],\n                      [[-0.03212379], [-0.03212379], [-0.0]],\n                  ],\n                  [\n                      [[-0.02046613], [-0.02046613], [-0.0]],\n                      [[-0.0], [-0.05392344], [-0.0]],\n                  ],\n              ],\n              [\n                  [[[0.0000000e00]], [[2.5190154e-04]]],\n                  [[[6.4247579e-05]], [[1.6061892e-03]]],\n                  [[[6.1398384e-04]], [[1.0784689e-03]]],\n              ],\n          ],\n      ),\n  )\n  def test_gradient(\n      self,\n      inputs,\n      location_parameters,\n      scaling_parameters,\n      units,\n      activation,\n      sparsity_factor,\n      scaling_exp_transform_multiplier,\n      expected,\n  ):\n    location_parameters = tf.Variable(\n        location_parameters,\n        trainable=True,\n        dtype=tf.float32,\n        name=\"location_parameters\",\n    )\n    scaling_parameters = tf.Variable(\n        scaling_parameters,\n        trainable=True,\n        dtype=tf.float32,\n        name=\"scaling_parameters\",\n    )\n\n    with tf.GradientTape() as tape:\n      y = cdf_fn(\n          inputs=tf.constant(inputs, dtype=tf.float32),\n          location_parameters=location_parameters,\n          scaling_parameters=scaling_parameters,\n          reduction=\"mean\",\n          units=units,\n          activation=activation,\n          
sparsity_factor=sparsity_factor,\n          scaling_exp_transform_multiplier=scaling_exp_transform_multiplier,\n      )\n      loss = tf.reduce_sum(y * y)\n    grads = tape.gradient(loss, [location_parameters, scaling_parameters])\n    self.assertAllClose(grads, expected)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"activation\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[2.0]], [[2.0]]],\n              [[[5.0]], [[7.0]]],\n          ],\n          units=2,\n          activation=\"relu\",\n          reduction=\"none\",\n          sparsity_factor=2,\n          expected=\"activation = .* is not supported.*\",\n      ),\n      dict(\n          testcase_name=\"reduction\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=None,\n          units=2,\n          activation=\"sigmoid\",\n          reduction=\"some_reduction\",\n          sparsity_factor=2,\n          expected=\"reduction = .* is not supported.*\",\n      ),\n      dict(\n          testcase_name=\"input_shape\",\n          inputs=[-1.0, 0.0],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[2.0]], [[2.0]]],\n              [[[5.0]], [[7.0]]],\n          ],\n          units=2,\n          activation=\"sigmoid\",\n          reduction=\"none\",\n          sparsity_factor=2,\n          expected=\"inputs shape.*is not.*\",\n      ),\n      dict(\n          testcase_name=\"units_and_sparsity_factor\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=None,\n          units=2,\n          activation=\"sigmoid\",\n          reduction=\"mean\",\n          sparsity_factor=3,\n          expected=\"units.*is not divisible by sparsity_factor.*\",\n      ),\n      dict(\n          testcase_name=\"input_dim_and_sparsity_factor\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[2.0]], [[2.0]]],\n              [[[5.0]], [[7.0]]],\n          ],\n          units=3,\n          activation=\"sigmoid\",\n          reduction=\"mean\",\n          sparsity_factor=3,\n          expected=\"input_dim.*is not divisible by sparsity_factor.*\",\n      
),\n      dict(\n          testcase_name=\"location_parameters_shape_1\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[-1.0], [0.0], [1.0]],\n              [[-2.0], [0.0], [2.0]],\n              [[-1.0], [0.0], [1.0]],\n              [[-3.0], [0.0], [3.0]],\n              [[-1.0], [0.0], [1.0]],\n              [[-4.0], [0.0], [4.0]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[2.0]], [[2.0]]],\n              [[[5.0]], [[7.0]]],\n          ],\n          units=2,\n          activation=\"sigmoid\",\n          reduction=\"mean\",\n          sparsity_factor=2,\n          expected=\"location_parameters shape.*is not.*\",\n      ),\n      dict(\n          testcase_name=\"location_parameters_shape_2\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n              [[[-1.0], [0.0], [1.0]]],\n          ],\n          scaling_parameters=None,\n          units=2,\n          activation=\"sigmoid\",\n          reduction=\"mean\",\n          sparsity_factor=2,\n          expected=\"location_parameters shape.*is not.*\",\n      ),\n      dict(\n          testcase_name=\"location_parameters_shape_3\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=None,\n          units=2,\n          activation=\"sigmoid\",\n          reduction=\"mean\",\n          sparsity_factor=1,\n          expected=\"location_parameters shape.*is not.*\",\n      ),\n      dict(\n          testcase_name=\"location_and_scaling_shape_1\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0]], [[1.0]]],\n              [[[2.0]], [[2.0]]],\n              [[[5.0]], [[7.0]]],\n              [[[5.0]], [[7.0]]],\n          ],\n          units=2,\n          activation=\"sigmoid\",\n          reduction=\"mean\",\n          sparsity_factor=2,\n          expected=(\n              \"scaling_parameters and location_parameters\"\n              \" likely are not broadcastable.*\"\n          ),\n      ),\n      dict(\n          testcase_name=\"location_and_scaling_shape_2\",\n          inputs=[[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]],\n          location_parameters=[\n              [[[-1.0], [0.0], [1.0]], [[-2.0], [0.0], [2.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-3.0], [0.0], [3.0]]],\n              [[[-1.0], [0.0], [1.0]], [[-4.0], [0.0], [4.0]]],\n          ],\n          scaling_parameters=[\n              [[[1.0, 1.0]], [[1.0, 1.0]]],\n              [[[2.0, 1.0]], [[2.0, 1.0]]],\n              [[[5.0, 1.0]], [[7.0, 1.0]]],\n          ],\n          units=2,\n          activation=\"sigmoid\",\n          reduction=\"mean\",\n          sparsity_factor=2,\n          expected=(\n              \"scaling_parameters and location_parameters\"\n              \" likely are not broadcastable.*\"\n          ),\n      ),\n  )\n  def test_raise(\n      
self,\n      inputs,\n      location_parameters,\n      scaling_parameters,\n      units,\n      activation,\n      reduction,\n      sparsity_factor,\n      expected,\n  ):\n    with self.assertRaisesRegex(ValueError, expected):\n      _ = cdf_fn(\n          inputs=tf.constant(inputs, dtype=tf.float32),\n          location_parameters=tf.constant(\n              location_parameters, dtype=tf.float32\n          ),\n          scaling_parameters=(\n              tf.constant(scaling_parameters, dtype=tf.float32)\n              if scaling_parameters is not None\n              else None\n          ),\n          units=units,\n          reduction=reduction,\n          activation=activation,\n          sparsity_factor=sparsity_factor,\n      )\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_lattice/python/conditional_pwl_calibration.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Implements PWLCalibration with derived parameters (kernels).\n\n`pwl_calibration_fn` is similar to `tfl.layers.PWLCalibration` with the key\ndifference that the keypoints are decided by the given parameters instead\nof learnable weights belonging to a layer. These parameters can be one of:\n\n  - constants,\n  - trainable variables,\n  - outputs from other TF modules.\n\nFor inputs of shape `(batch_size, units)`, two sets of parameters are required\nto configure the piece-wise linear calibrator in terms of its x and y values:\n\n - `keypoint_input_parameters` for configuring the x values,\n - `keypoint_output_parameters` for configuring the y values.\n\nThis function is a general form of conditional calibration, that one input\nvariable is calibrated based on free form parameters coming from other variables\nand their transformations.\n\nShapes:\nThe last dimension sizes of `keypoint_input_parameters` (input_param_size) and\n`keypoint_output_parameters` (output_param_size) depend on the number of\nkeypoints used by the calibrator. We follow the relationships that\n\n - input_param_size = # keypoints - 2, as the leftmost and rightmost keypoints\n   are given.\n - output_param_size = # keypoints initially, and we then modify it by\n\n   1. if cyclic calibrator: output_param_size -= 1,\n   2. if clamp_min: output_param_size -= 1,\n   3. if clamp_max: output_param_size -= 1,\n   4. if need to learn how to impute missing: output_param_size += 1.\n\nThe final shapes need to be broadcast friendly with `(batch_size, units, 1)`:\n\n - `keypoint_input_parameters`:\n   `(1 or batch_size, 1 or units, input_param_size)`.\n - `keypoint_output_parameters`:\n   `(1 or batch_size, 1 or units, output_param_size)`.\n\"\"\"\n\nfrom typing import Optional, Sequence, Tuple, Union\n\nimport numpy as np\nimport tensorflow as tf\n\n\ndef _front_pad(x: tf.Tensor, constant_values: float) -> tf.Tensor:\n  return tf.pad(x, [[0, 0], [0, 0], [1, 0]], constant_values=constant_values)\n\n\ndef default_keypoint_output_parameters(\n    num_keypoints: int,\n    units: int = 1,\n    monotonicity: str = \"none\",\n    is_cyclic: bool = False,\n    clamp_min: bool = False,\n    clamp_max: bool = False,\n    derived_missing_output: bool = False,\n) -> Optional[tf.Tensor]:\n  \"\"\"Helper creating default `keypoint_output_parameters`.\n\n  Primarily used for testing.\n\n  Args:\n    num_keypoints: number of keypoints for inputs.\n    units: number of parallel calibrations on one input.\n    monotonicity: `none` or `increasing`, monotonicity of the calibration.\n    is_cyclic: whether the calibration is cyclic. Only works if `monotonicity ==\n      none`.\n    clamp_min: whether the leftmost keypoint should be clamped. Only works if\n      `monotonicity == increasing`.\n    clamp_max: whether the rightmost keypoint should be clamped. 
Only works if\n      `monotonicity == increasing`.\n    derived_missing_output: whether to reserve a placeholder for the missing\n      output value.\n\n  Returns:\n    A tensor with a shape of `(1, units, output_param_size)`.\n\n  Raises:\n    `ValueError` if parsing failed.\n  \"\"\"\n  if monotonicity == \"none\":\n    output_param_size = num_keypoints - is_cyclic + derived_missing_output\n    # default output = midpoint between\n    # keypoint_output_min and keypoint_output_max, flat.\n    return tf.zeros((1, units, output_param_size), dtype=tf.float32)\n  elif monotonicity == \"increasing\":\n    output_param_size = (\n        num_keypoints - clamp_min - clamp_max + derived_missing_output\n    )\n    # default output = equal increments between\n    # keypoint_output_min and keypoint_output_max.\n    return tf.zeros((1, units, output_param_size), dtype=tf.float32)\n  else:\n    raise ValueError(f\"Unknown monotonicity: {monotonicity}\")\n\n\ndef default_keypoint_input_parameters(\n    num_keypoints: Optional[int] = None,\n    keypoints: Optional[Sequence[float]] = None,\n    units: int = 1,\n) -> Optional[tf.Tensor]:\n  \"\"\"Helper creating default `keypoint_input_parameters`.\n\n  Primarily used for testing.\n\n  Args:\n    num_keypoints: number of keypoints. If provided, keypoints will be equally\n      spaced.\n    keypoints: sequence of increasing keypoints.\n    units: number of parallel calibrations on one input.\n\n  Returns:\n    A tensor with a shape of `(1, units, input_param_size)`.\n\n  Raises:\n    `ValueError` if parsing failed.\n  \"\"\"\n  if num_keypoints is not None and num_keypoints > 2:\n    return tf.zeros((1, units, num_keypoints - 2), dtype=tf.float32)\n  elif keypoints is not None:\n    keypoints = np.array(keypoints)\n    deltas = keypoints[1:] - keypoints[:-1]\n    if np.all(deltas > 0):\n      deltas = deltas / np.sum(deltas)\n      deltas = np.log(deltas / deltas[0])[1:]\n      deltas = tf.reshape(tf.constant(deltas, dtype=tf.float32), (1, 1, -1))\n      return tf.tile(deltas, [1, units, 1])\n  else:\n    raise ValueError(\"Neither num_keypoints nor keypoints is specified.\")\n\n\ndef _verify_pwl_calibration(\n    inputs,\n    keypoint_input_parameters,\n    keypoint_output_parameters,\n    units,\n    keypoint_input_min,\n    keypoint_input_max,\n    keypoint_output_min,\n    keypoint_output_max,\n    clamp_min,\n    clamp_max,\n    monotonicity,\n    is_cyclic,\n    missing_input_value,\n    missing_output_value,\n):\n  \"\"\"Validates calibration arguments.\"\"\"\n  # Validate keypoint input_min and input_max.\n  if keypoint_input_min > keypoint_input_max:\n    raise ValueError(\n        f\"keypoint_input_min = {keypoint_input_min} > keypoint_input_max =\"\n        f\" {keypoint_input_max}.\"\n    )\n\n  # Validate pwl shape arguments.\n  if monotonicity not in (\"none\", \"increasing\"):\n    raise ValueError(\n        \"Monotonicity should be 'none' or 'increasing'. 
\"\n        f\"Given '{monotonicity}'.\"\n    )\n\n  if monotonicity == \"none\" and (clamp_min or clamp_max):\n    raise ValueError(\"Cannot clamp to min or max when monotonicity is 'none'.\")\n\n  if keypoint_output_min > keypoint_output_max:\n    raise ValueError(\n        f\"keypoint_output_min = {keypoint_output_min} > keypoint_output_max =\"\n        f\" {keypoint_output_max}.\"\n    )\n\n  if monotonicity == \"increasing\" and is_cyclic:\n    raise ValueError(\"Monotonicity should be 'none' when is_cyclic=True.\")\n\n  # Validate missingness indicators.\n  if missing_output_value is not None and missing_input_value is None:\n    raise ValueError(\n        \"missing_output_value is set, but missing_input_value is None\"\n    )\n\n  # Validate parameter shapes. See module level doc string for details.\n  num_keypoints = (\n      keypoint_input_parameters.shape[-1] + 2\n      if keypoint_input_parameters is not None\n      else 0\n  )\n  output_param_size = (\n      num_keypoints\n      - clamp_max\n      - clamp_min\n      - is_cyclic\n      + (missing_input_value is not None)\n      - (missing_output_value is not None)\n  )\n\n  if output_param_size <= 0:\n    raise ValueError(\n        f\"Required keypoint_output_parameters per example = {output_param_size}\"\n        \" <= 0: Creating a trivial function, e.g. identity or constant.\"\n    )\n\n  if units > 1 and len(keypoint_output_parameters.shape) != 3:\n    raise ValueError(\n        \"keypoint_output_parameters should be 3 dimensional when units > 1. \"\n        f\"Given {keypoint_output_parameters.shape}.\"\n    )\n  if (\n      len(keypoint_output_parameters.shape) == 3\n      and keypoint_output_parameters.shape[1] != units\n  ):\n    raise ValueError(\n        \"2nd dimension of keypoint_output_parameters does not match units, \"\n        f\"units = {units} vs keypoint_output_parameters = \"\n        f\"{keypoint_output_parameters.shape[1]}.\"\n    )\n  if keypoint_output_parameters.shape[-1] != output_param_size:\n    raise ValueError(\n        \"keypoint_output_parameters shape is \"\n        f\"{keypoint_output_parameters.shape} whose last dimension needs to be \"\n        f\"{output_param_size}.\"\n    )\n\n  # Validate input shape.\n  if inputs.shape[1] > 1 and inputs.shape[1] != units:\n    raise ValueError(\n        \"2nd dimension of input shape does not match units > 1, \"\n        f\"Require (batch_size, 1) or (batch_size, units = {units}).\"\n    )\n\n\ndef _compute_interpolation_weights(inputs, keypoints, lengths):\n  \"\"\"Computes weights for PWL calibration.\n\n  Args:\n    inputs: Tensor of shape: `(batch_size, units, 1)`. 
For multi-unit\n      calibration, broadcasting will be used if needed.\n    keypoints: Tensor of shape `(num_keypoints-1)` which represents left\n      keypoint of pieces of piecewise linear function along X axis.\n    lengths: Tensor of shape `(num_keypoints-1)` which represents lengths of\n      pieces of piecewise linear function along X axis.\n\n  Returns:\n    Interpolation weights tensor of shape: `(batch_size, units, num_keypoints)`.\n  \"\"\"\n  # weights always matches the shape of inputs.\n  weights = (inputs - keypoints) / lengths\n  weights = tf.clip_by_value(weights, 0.0, 1.0)\n  return _front_pad(weights, 1.0)\n\n\n@tf.function\ndef pwl_calibration_fn(\n    inputs: tf.Tensor,\n    keypoint_input_parameters: Optional[tf.Tensor],\n    keypoint_output_parameters: tf.Tensor,\n    keypoint_input_min: float = 0.0,\n    keypoint_input_max: float = 1.0,\n    keypoint_output_min: float = 0.0,\n    keypoint_output_max: float = 1.0,\n    units: int = 1,\n    monotonicity: str = \"none\",\n    clamp_min: bool = False,\n    clamp_max: bool = False,\n    is_cyclic: bool = False,\n    missing_input_value: Optional[float] = None,\n    missing_output_value: Optional[float] = None,\n    return_derived_parameters: bool = False,\n) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor, tf.Tensor]]:\n  \"\"\"Calibrates `inputs` using derived parameters (kernels).\n\n  `pwl_calibration_fn` is similar to `tfl.layers.PWLCalibration` with the key\n  difference that the keypoints are decided by the given parameters instead\n  of learnable weights belonging to a layer. These parameters can be one of:\n\n    - constants,\n    - trainable variables,\n    - outputs from other TF modules.\n\n  Shapes:\n  The last dimension of `keypoint_input_parameters` (`input_param_size`) and\n  `keypoint_output_parameters` (`output_param_size`) depend on the number of\n  keypoints used by the calibrator. We follow the relationships that\n\n  - `input_param_size = # keypoints - 2`, as the leftmost and rightmost\n    keypoints are given.\n  - `output_param_size = # keypoints` initially, and we then modify it by\n\n    1. if cyclic calibrator: `output_param_size -= 1`,\n    2. if clamp_min: `output_param_size -= 1`,\n    3. if clamp_max: `output_param_size -= 1`,\n    4. 
if need to learn how to impute missing: `output_param_size += 1`.\n\n  The final shapes need to be broadcast friendly with `(batch_size, units, 1)`:\n\n  - `keypoint_input_parameters`:\n    `(1 or batch_size, 1 or units, input_param_size)`.\n  - `keypoint_output_parameters`:\n    `(1 or batch_size, 1 or units, output_param_size)`.\n\n  Input shape:\n    `inputs` should be one of:\n\n      - `(batch_size, 1)` if `units == 1`.\n      - `(batch_size, 1)` or `(batch_size, units)` if `units > 1`.\n        The former will be broadcast to match units.\n\n    `keypoint_input_parameters` should be one of:\n\n      - `None` if only the leftmost and the rightmost keypoints are required.\n      - `(1, input_param_size)`.\n      - `(batch_size, input_param_size)`.\n      - `(1, 1, input_param_size)`.\n      - `(batch_size, 1, input_param_size)`.\n      - `(1, units, input_param_size)`.\n      - `(batch_size, units, input_param_size)`.\n\n    `keypoint_output_parameters` should be one of:\n\n      - `(1, output_param_size)`.\n      - `(batch_size, output_param_size)`.\n      - `(1, 1, output_param_size)`.\n      - `(batch_size, 1, output_param_size)`.\n      - `(1, units, output_param_size)`.\n      - `(batch_size, units, output_param_size)`.\n\n  Args:\n    inputs: inputs to the calibration fn.\n    keypoint_input_parameters: parameters for keypoint x's of calibration fn.\n    keypoint_output_parameters: parameters for keypoint y's of calibration fn.\n    keypoint_input_min: the leftmost keypoint.\n    keypoint_input_max: the rightmost keypoint.\n    keypoint_output_min: lower bound of the fn output.\n    keypoint_output_max: upper bound of the fn output.\n    units: number of parallel calibrations on one input.\n    monotonicity: `none` or `increasing`. Whether the calibration is monotonic.\n    clamp_min: only applies when monotonicity == `increasing`. Whether to clamp\n      the LHS keypoint to the calibration `keypoint_output_min`.\n    clamp_max: only applies when monotonicity == `increasing`. Whether to clamp\n      the RHS keypoint to the calibration `keypoint_output_max`.\n    is_cyclic: only applies when monotonicity == `none`. Whether the LHS and the\n      RHS keypoints have the same calibration output.\n    missing_input_value: if set, use as the value indicating a missing input.\n    missing_output_value: if set, use as the output for `missing_input_value`.\n    return_derived_parameters: if True, return the derived kernel parameters\n      used for interpolation.\n\n  Returns:\n    If `return_derived_parameters = False`:\n\n      - The calibrated output as a tensor with shape `(batch_size, units)`.\n\n    If `return_derived_parameters == True`:\n\n      - A tuple of three elements:\n\n        1. The calibrated output as a tensor with shape `(batch_size, units)`.\n        2. The deltas between the keypoints x's with shape\n          `(batch_size, units, # keypoints - 1)`.\n        3. The initial value and the deltas between the keypoints y's, with\n          shape `(batch_size, units, # keypoints)`. Applying `cumsum` will\n          reconstruct the y values.\n
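\n  For example (illustrative numbers only): with 4 keypoints, `units = 1`, an\n  increasing calibrator, no clamping and no missing-value handling, one would\n  pass `keypoint_input_parameters` with last dimension `4 - 2 = 2` and\n  `keypoint_output_parameters` with last dimension `4`; the output then has\n  shape `(batch_size, 1)`.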
\n  \"\"\"\n  _verify_pwl_calibration(\n      inputs=inputs,\n      keypoint_input_parameters=keypoint_input_parameters,\n      keypoint_output_parameters=keypoint_output_parameters,\n      units=units,\n      keypoint_input_min=keypoint_input_min,\n      keypoint_input_max=keypoint_input_max,\n      keypoint_output_min=keypoint_output_min,\n      keypoint_output_max=keypoint_output_max,\n      clamp_min=clamp_min,\n      clamp_max=clamp_max,\n      monotonicity=monotonicity,\n      is_cyclic=is_cyclic,\n      missing_input_value=missing_input_value,\n      missing_output_value=missing_output_value,\n  )\n\n  if keypoint_input_parameters is None:\n    keypoint_input_parameters = tf.zeros((1, units, 1), dtype=tf.float32)\n  else:\n    if len(keypoint_input_parameters.shape) == 2:\n      keypoint_input_parameters = keypoint_input_parameters[:, tf.newaxis, :]\n    if keypoint_input_parameters.shape[1] == 1 and units > 1:\n      keypoint_input_parameters = tf.tile(\n          keypoint_input_parameters, [1, units, 1]\n      )\n    # Front-pad 0 to normalize softmax.\n    keypoint_input_parameters = _front_pad(keypoint_input_parameters, 0.0)\n\n  keypoint_deltas = tf.nn.softmax(keypoint_input_parameters, axis=-1) * (\n      keypoint_input_max - keypoint_input_min\n  )\n  # Front-pad `input_min` as the leftmost keypoint.\n  # Trim the rightmost keypoint not required for interpolation.\n  keypoints = (\n      tf.cumsum(keypoint_deltas, exclusive=True, axis=-1) + keypoint_input_min\n  )\n\n  # Rename since its value will be modified as part of the output.\n  kernel_outputs = keypoint_output_parameters\n  if len(kernel_outputs.shape) == 2:\n    kernel_outputs = kernel_outputs[:, tf.newaxis, :]\n  if kernel_outputs.shape[1] == 1 and units > 1:\n    kernel_outputs = tf.tile(kernel_outputs, [1, units, 1])\n\n  missing_output = None\n  if missing_input_value is not None:\n    if missing_output_value is None:\n      # The last parameter is used to derive the imputed output value after\n      # sigmoid and rescale.\n      missing_output = keypoint_output_min + tf.sigmoid(\n          kernel_outputs[:, :, -1]\n      ) * (keypoint_output_max - keypoint_output_min)\n      kernel_outputs = kernel_outputs[:, :, :-1]\n    else:\n      missing_output = tf.fill(\n          kernel_outputs[:, :, -1].shape, missing_output_value\n      )\n\n  if monotonicity == \"none\":\n    kernel_outputs = (\n        tf.sigmoid(kernel_outputs) * (keypoint_output_max - keypoint_output_min)\n        + keypoint_output_min\n    )\n    if is_cyclic:\n      kernel_outputs = tf.concat(\n          [kernel_outputs, kernel_outputs[:, :, :1]], axis=-1\n      )\n    # Transform to [initial value, delta_0, delta_1,...].\n    kernel_outputs = tf.concat(\n        [\n            kernel_outputs[:, :, :1],\n            kernel_outputs[:, :, 1:] - kernel_outputs[:, :, :-1],\n        ],\n        axis=-1,\n    )\n  else:  # monotonicity == \"increasing\"\n    # Front-pad zero to normalize softmax.\n    kernel_outputs = _front_pad(kernel_outputs, 0.0)\n    kernel_outputs = tf.nn.softmax(kernel_outputs, axis=-1) * (\n        keypoint_output_max - keypoint_output_min\n    )\n    if clamp_min:\n      # Front-pad keypoint_output_min to the kernel_outputs.\n      kernel_outputs = _front_pad(kernel_outputs, keypoint_output_min)\n    else:\n      # Add keypoint_output_min to the LHS element in the kernel_outputs.\n      # TODO: test tf.tensor_scatter_nd_add.\n      kernel_outputs = tf.concat(\n          
[\n              kernel_outputs[:, :, :1] + keypoint_output_min,\n              kernel_outputs[:, :, 1:],\n          ],\n          axis=-1,\n      )\n    if not clamp_max:\n      # Drop the RHS element in the kernel_outputs which made cumsum = 1.\n      kernel_outputs = kernel_outputs[:, :, :-1]\n\n  if units > 1 and inputs.shape[-1] == 1:\n    inputs = tf.tile(inputs, [1, units])\n  weights = _compute_interpolation_weights(\n      tf.reshape(inputs, (-1, units, 1)), keypoints, keypoint_deltas\n  )\n  outputs = tf.reduce_sum(weights * kernel_outputs, axis=-1, keepdims=False)\n\n  if missing_input_value is not None:\n    outputs = tf.where(\n        tf.equal(inputs, missing_input_value), missing_output, outputs\n    )\n\n  if return_derived_parameters:\n    return outputs, keypoint_deltas, kernel_outputs\n  else:\n    return outputs\n"
  },
  {
    "path": "tensorflow_lattice/python/conditional_pwl_calibration_test.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TF tests for pwl_calibration_fn.py.\"\"\"\n\nimport tensorflow as tf\n\nfrom tensorflow_lattice.python.conditional_pwl_calibration import default_keypoint_input_parameters\nfrom tensorflow_lattice.python.conditional_pwl_calibration import pwl_calibration_fn\n\n_EPSILON = 1e-4\n\n\nclass PwlCalibrationFnTest(tf.test.TestCase):\n\n  def assertAllClose(self, x, y):\n    super().assertAllClose(x, y, rtol=_EPSILON, atol=_EPSILON)\n\n  def assertAllGreaterEqual(self, a, comparison_target):\n    super().assertAllGreaterEqual(a, comparison_target - _EPSILON)\n\n  def assertAllLessEqual(self, a, comparison_target):\n    super().assertAllLessEqual(a, comparison_target + _EPSILON)\n\n  def assertAllEqual(self, a, comparison_target):\n    super().assertAllInRange(\n        a, comparison_target - _EPSILON, comparison_target + _EPSILON\n    )\n\n  def setUp(self):\n    super().setUp()\n    self.kernel_4 = tf.constant(\n        [\n            [-0.38, -0.41, -0.34, -0.29],\n            [0.17, -0.32, 0.33, -0.1],\n        ],\n        dtype=tf.float32,\n    )\n    self.kernel_5 = tf.constant(\n        [\n            [-0.38, -0.41, -0.34, -0.29, 0.42],\n            [0.17, -0.32, 0.33, -0.1, -0.36],\n        ],\n        dtype=tf.float32,\n    )\n    self.multi_unit_kernel_4 = tf.constant(\n        [\n            [\n                [-0.26, 0.43, 0.49, 0.26],\n                [0.39, 0.42, -0.33, 0.41],\n                [0.28, 0.04, 0.46, 0.09],\n            ],\n            [\n                [-0.27, -0.23, 0.29, -0.12],\n                [-0.4, -0.24, -0.31, 0.01],\n                [0.03, 0.01, -0.42, -0.42],\n            ],\n        ],\n        dtype=tf.float32,\n    )\n    self.multi_unit_kernel_5 = tf.constant(\n        [\n            [\n                [-0.26, 0.43, 0.49, 0.26, -0.32],\n                [0.39, 0.42, -0.33, 0.41, 0.11],\n                [0.28, 0.04, 0.46, 0.09, -0.33],\n            ],\n            [\n                [-0.27, -0.23, 0.29, -0.12, 0.46],\n                [-0.4, -0.24, -0.31, 0.01, 0.21],\n                [0.03, 0.01, -0.42, -0.42, 0.37],\n            ],\n        ],\n        dtype=tf.float32,\n    )\n\n  def test_suite_none_monotonic(self):\n    \"\"\"Tests non-monotonic calibration.\"\"\"\n    # basic call\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[0.5], [0.8]]),\n        keypoint_output_parameters=self.kernel_4,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n    )\n    print(default_keypoint_input_parameters(keypoints=[0.0, 0.1, 0.4, 1.0]))\n    self.assertAllClose(y, tf.constant([[0.41784188], [0.51060027]]))\n\n    # if is_cyclic, starting and ending keypoints give the same prediction\n    y1 = pwl_calibration_fn(\n        inputs=tf.constant([[0.5], [0.5]]),\n        keypoint_output_parameters=self.kernel_4,\n        
is_cyclic=True,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.5, 0.6, 0.65, 0.7, 0.8]\n        ),\n        keypoint_input_min=0.5,\n        keypoint_input_max=0.8,\n    )\n    y2 = pwl_calibration_fn(\n        inputs=tf.constant([[0.8], [0.8]]),\n        keypoint_output_parameters=self.kernel_4,\n        is_cyclic=True,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.5, 0.6, 0.65, 0.7, 0.8]\n        ),\n        keypoint_input_min=0.5,\n        keypoint_input_max=0.8,\n    )\n    self.assertAllClose(y1, y2)\n\n    # basic multi-unit call, input needs broadcast\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[0.5], [0.8]]),\n        units=3,\n        keypoint_output_parameters=self.multi_unit_kernel_4,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n    )\n    self.assertAllClose(\n        y,\n        tf.constant([\n            [0.6108614, 0.44871515, 0.5979259],\n            [0.50402266, 0.47603822, 0.39651677],\n        ]),\n    )\n\n    # basic multi-unit call\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[0.5, 0.5, 0.5], [0.8, 0.8, 0.8]]),\n        units=3,\n        keypoint_output_parameters=self.multi_unit_kernel_4,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n    )\n    self.assertAllClose(\n        y,\n        tf.constant([\n            [0.6108614, 0.44871515, 0.5979259],\n            [0.50402266, 0.47603822, 0.39651677],\n        ]),\n    )\n\n    # keypoint_output_min and keypoint_output_max scales correctly\n    y1 = pwl_calibration_fn(\n        inputs=tf.constant([[0.5, 0.5, 0.5], [0.8, 0.8, 0.8]]),\n        units=3,\n        keypoint_output_parameters=self.multi_unit_kernel_4,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n    )\n    y2 = pwl_calibration_fn(\n        inputs=tf.constant([[0.5, 0.5, 0.5], [0.8, 0.8, 0.8]]),\n        units=3,\n        keypoint_output_parameters=self.multi_unit_kernel_4,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0]\n        ),\n        keypoint_output_min=-1.0,\n        keypoint_output_max=10.0,\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n    )\n    self.assertAllClose(y1 * 11.0 - 1.0, y2)\n\n    # multi-unit is_cyclic gives cyclic predictions\n    y1 = pwl_calibration_fn(\n        inputs=tf.constant([[-0.1], [1.1]]),\n        units=3,\n        is_cyclic=True,\n        keypoint_output_parameters=self.multi_unit_kernel_4,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.2, 0.5, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n    )\n    y2 = pwl_calibration_fn(\n        inputs=tf.constant([[-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]]),\n        units=3,\n        is_cyclic=True,\n        keypoint_output_parameters=self.multi_unit_kernel_4,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.2, 0.5, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        
keypoint_input_max=1.0,\n    )\n    self.assertAllClose(y1, y2)\n\n    # missing input with given missing output imputed correctly\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[0.5], [-1.0]]),\n        keypoint_output_parameters=self.kernel_4,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n        missing_input_value=-1.0,\n        missing_output_value=3.0,\n    )\n    self.assertAllClose(y, tf.constant([[0.41784188], [3.0]]))\n\n    # missing input imputed correctly with derived missing output\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[0.5], [-1.0]]),\n        keypoint_output_parameters=self.kernel_5,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n        missing_input_value=-1.0,\n    )\n    self.assertAllClose(y, tf.constant([[0.41784188], [0.41095957]]))\n\n  def test_suite_increasing_monotonic(self):\n    \"\"\"Tests monotonic calibration.\"\"\"\n    # basic call\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[0.5], [0.8]]),\n        keypoint_output_parameters=self.kernel_4,\n        monotonicity='increasing',\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n    )\n    self.assertAllClose(y, tf.constant([[0.64769804], [0.7371951]]))\n\n    # outputs are monotonic\n    y1 = pwl_calibration_fn(\n        inputs=tf.constant([[-0.5], [0.3]]),\n        keypoint_output_parameters=self.kernel_4,\n        monotonicity='increasing',\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n    )\n    y2 = pwl_calibration_fn(\n        inputs=tf.constant([[0.5], [0.8]]),\n        keypoint_output_parameters=self.kernel_4,\n        monotonicity='increasing',\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n    )\n    y3 = pwl_calibration_fn(\n        inputs=tf.constant([[0.6], [1.2]]),\n        keypoint_output_parameters=self.kernel_4,\n        monotonicity='increasing',\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=1.0,\n    )\n    self.assertAllGreaterEqual(y2 - y1, 0.0)\n    self.assertAllGreaterEqual(y3 - y2, 0.0)\n\n    # clamp_min works as expected\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[0.0], [-0.2]]),\n        keypoint_output_parameters=self.multi_unit_kernel_4,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            num_keypoints=5\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=2.0,\n        monotonicity='increasing',\n        keypoint_output_min=-10.0,\n        clamp_min=True,\n        units=3,\n    )\n    self.assertAllEqual(y, -10.0)\n\n    # clamp_out works as expected\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[2.0], [2.5]]),\n        keypoint_output_parameters=self.multi_unit_kernel_4,\n        
keypoint_input_parameters=default_keypoint_input_parameters(\n            num_keypoints=5\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=2.0,\n        monotonicity='increasing',\n        keypoint_output_max=10.0,\n        clamp_max=True,\n        units=3,\n    )\n    self.assertAllEqual(y, 10.0)\n\n    # clamp_min and clamp_out work as expected together, min\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[0.0, 0.0, -10.0], [-0.2, 0.0, -100.0]]),\n        keypoint_output_parameters=self.multi_unit_kernel_4,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0, 1.5, 2.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=2.0,\n        monotonicity='increasing',\n        keypoint_output_min=-10.0,\n        clamp_min=True,\n        keypoint_output_max=5.0,\n        clamp_max=True,\n        units=3,\n    )\n    self.assertAllEqual(y, -10.0)\n\n    # clamp_min and clamp_out work as expected together, max\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[2.0, 3.0, 4.0], [2.5, 2.5, 2.5]]),\n        keypoint_output_parameters=self.multi_unit_kernel_4,\n        units=3,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0, 1.5, 2.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=2.0,\n        monotonicity='increasing',\n        keypoint_output_min=-10.0,\n        clamp_min=True,\n        keypoint_output_max=5.0,\n        clamp_max=True,\n    )\n    self.assertAllEqual(y, 5.0)\n\n    # clamp_min, clamp_out, missing_input_value and derived missing_output_value\n    # work as expected together\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[0.0, 1.0, 2.0], [-0.5, 1.5, 2.5]]),\n        keypoint_output_parameters=self.multi_unit_kernel_5,\n        units=3,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0, 1.5, 2.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=2.0,\n        monotonicity='increasing',\n        keypoint_output_min=-10.0,\n        clamp_min=True,\n        keypoint_output_max=5.0,\n        clamp_max=True,\n        missing_input_value=-1.0,\n    )\n    self.assertAllClose(\n        y, tf.constant([[-10.0, -0.3635044, 5.0], [-10.0, 1.3930602, 5.0]])\n    )\n\n    # clamp_min, clamp_out and missing_input_value work as expected together\n    y = pwl_calibration_fn(\n        inputs=tf.constant([[0.0, -1.0, 2.0], [-0.5, -1.0, 2.5]]),\n        keypoint_output_parameters=self.multi_unit_kernel_4,\n        units=3,\n        keypoint_input_parameters=default_keypoint_input_parameters(\n            keypoints=[0.0, 0.1, 0.4, 1.0, 1.5, 2.0]\n        ),\n        keypoint_input_min=0.0,\n        keypoint_input_max=2.0,\n        monotonicity='increasing',\n        keypoint_output_min=-10.0,\n        clamp_min=True,\n        keypoint_output_max=5.0,\n        clamp_max=True,\n        missing_input_value=-1.0,\n        missing_output_value=3.0,\n    )\n    self.assertAllClose(y, tf.constant([[-10.0, 3.0, 5.0], [-10.0, 3.0, 5.0]]))\n\n  def test_gradient_step(self):\n    \"\"\"Tests gradient computation.\"\"\"\n    trainable = tf.Variable(\n        tf.zeros_like(self.multi_unit_kernel_5, dtype=tf.float32),\n        trainable=True,\n        name='trainable',\n    )\n\n    with tf.GradientTape() as tape:\n      y = pwl_calibration_fn(\n          inputs=tf.constant([[-1.0, 0.0, 1.0], [0.8, 2.0, 3.0]]),\n 
         keypoint_output_parameters=trainable,\n          units=3,\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.1, 0.4, 1.0, 2.0]\n          ),\n          keypoint_input_min=0.0,\n          keypoint_input_max=2.0,\n          monotonicity='increasing',\n          keypoint_output_max=10.0,\n          clamp_max=True,\n          missing_input_value=-1.0,\n      )\n      loss = tf.reduce_mean(y * y)\n    grads = tape.gradient(loss, trainable)\n    self.assertAllClose(\n        grads,\n        tf.constant([\n            [\n                [0.0, 0.0, 0.0, 0.0, 4.166667],\n                [-0.26666668, -0.26666668, -0.26666668, -0.26666668, 0.0],\n                [1.0666668, 1.0666668, 1.0666668, -4.266667, 0.0],\n            ],\n            [\n                [1.3037037, 1.3037037, -0.3259262, -3.5851853, 0.0],\n                [0.0, 0.0, 0.0, 0.0, 0.0],\n                [0.0, 0.0, 0.0, 0.0, 0.0],\n            ],\n        ]),\n    )\n\n  def test_suite_raises(self):\n    \"\"\"Tests verifiable ValueErrors.\"\"\"\n\n    with self.assertRaises(ValueError):\n      _ = pwl_calibration_fn(\n          inputs=tf.constant([[0.5], [0.8]]),\n          keypoint_output_parameters=self.kernel_4,\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.1, 0.1, 1.0]\n          ),\n      )\n    with self.assertRaises(ValueError):\n      _ = pwl_calibration_fn(\n          inputs=tf.constant([[0.5], [0.8]]),\n          keypoint_output_parameters=self.kernel_4,\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.3, 0.1, 1.0]\n          ),\n      )\n    with self.assertRaises(ValueError):\n      _ = pwl_calibration_fn(\n          inputs=tf.constant([[0.5], [0.8]]),\n          keypoint_output_parameters=self.kernel_4,\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.1, 0.2, 0.3, 1.0]\n          ),\n      )\n    with self.assertRaises(ValueError):\n      _ = pwl_calibration_fn(\n          inputs=tf.constant([[0.5], [0.8]]),\n          keypoint_output_parameters=self.kernel_4,\n          is_cyclic=True,\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.1, 0.3, 1.0]\n          ),\n      )\n    with self.assertRaises(ValueError):\n      _ = pwl_calibration_fn(\n          inputs=tf.constant([[0.5], [0.8]]),\n          keypoint_output_parameters=self.kernel_4,\n          units=3,\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.1, 0.5, 1.0]\n          ),\n      )\n    with self.assertRaises(ValueError):\n      _ = pwl_calibration_fn(\n          inputs=tf.constant([[0.5], [0.8]]),\n          keypoint_output_parameters=self.kernel_4,\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.1, 0.4, 1.0]\n          ),\n          keypoint_output_min=1.0,\n          keypoint_output_max=0.0,\n      )\n    with self.assertRaises(ValueError):\n      _ = pwl_calibration_fn(\n          inputs=tf.constant([[0.5], [0.8]]),\n          keypoint_output_parameters=self.kernel_4,\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.1, 0.4, 1.0]\n          ),\n          missing_output_value=1.0,\n      )\n    with self.assertRaises(ValueError):\n      _ = pwl_calibration_fn(\n          inputs=tf.constant([[0.5], 
[0.8]]),\n          keypoint_output_parameters=self.kernel_4,\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.1, 0.4, 1.0]\n          ),\n          keypoint_input_min=1.0,\n          keypoint_input_max=0.0,\n      )\n    with self.assertRaises(ValueError):\n      _ = pwl_calibration_fn(\n          inputs=tf.constant([[0.5], [0.8]]),\n          keypoint_output_parameters=self.multi_unit_kernel_4,\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.1, 0.4, 1.0]\n          ),\n      )\n    with self.assertRaises(ValueError):\n      _ = pwl_calibration_fn(\n          inputs=tf.constant([[0.5], [0.8]]),\n          keypoint_output_parameters=self.kernel_5,\n          monotonicity='increasing',\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.1, 0.4, 1.0]\n          ),\n      )\n    with self.assertRaises(ValueError):\n      _ = pwl_calibration_fn(\n          inputs=tf.constant([[0.5, 0.6, 0.7, 0.8], [0.0, 0.1, 0.2, 0.8]]),\n          units=3,\n          keypoint_output_parameters=self.multi_unit_kernel_5,\n          monotonicity='increasing',\n          keypoint_input_parameters=default_keypoint_input_parameters(\n              keypoints=[0.0, 0.1, 0.4, 0.7, 1.0]\n          ),\n      )\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_lattice/python/configs.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TFL model configuration library for canned estimators.\n\nTo construct a TFL canned estimator, construct a model configuration and pass\nit to the canned estimator constructor:\n\n```python\nfeature_columns = ...\nmodel_config = tfl.configs.CalibratedLatticeConfig(...)\nfeature_analysis_input_fn = create_input_fn(num_epochs=1, ...)\ntrain_input_fn = create_input_fn(num_epochs=100, ...)\nestimator = tfl.estimators.CannedClassifier(\n    feature_columns=feature_columns,\n    model_config=model_config,\n    feature_analysis_input_fn=feature_analysis_input_fn)\nestimator.train(input_fn=train_input_fn)\n```\n\nSupported models are:\n\n*   **Calibrated linear model**: Constructed using\n    `tfl.configs.CalibratedLinearConfig`.\n    A calibrated linear model that applies piecewise-linear and categorical\n    calibration on the input feature, followed by a linear combination and an\n    optional output piecewise-linear calibration. When using output calibration\n    or when output bounds are specified, the linear layer will apply weighted\n    averaging on calibrated inputs.\n\n*   **Calibrated lattice model**: Constructed using\n    `tfl.configs.CalibratedLatticeConfig`.\n    A calibrated lattice model applies piecewise-linear and categorical\n    calibration on the input feature, followed by a lattice model and an\n    optional output piecewise-linear calibration.\n\n*   **Calibrated lattice ensemble model**: Constructed using\n    `tfl.configs.CalibratedLatticeEnsembleConfig`.\n    A calibrated lattice ensemble model applies piecewise-linear and categorical\n    calibration on the input feature, followed by an ensemble of lattice models\n    and an optional output piecewise-linear calibration.\n\nFeature calibration and per-feature configurations are set using\n`tfl.configs.FeatureConfig`. 
Feature configurations include monotonicity\nconstraints, per-feature regularization (see `tfl.configs.RegularizerConfig`),\nand lattice sizes for lattice models.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport copy\n\nfrom absl import logging\nimport tensorflow as tf\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n_HPARAM_FEATURE_PREFIX = 'feature'\n_HPARAM_REGULARIZER_PREFIX = 'regularizer'\n\n\nclass _Config(object):\n  \"\"\"Base class for configs.\"\"\"\n\n  def __init__(self, kwargs):\n    if 'self' in kwargs:\n      kwargs.pop('self')\n    if '__class__' in kwargs:\n      kwargs.pop('__class__')\n    self.__dict__ = kwargs\n\n  def __repr__(self):\n    return self.__dict__.__repr__()\n\n  def get_config(self):\n    \"\"\"Returns a configuration dictionary.\"\"\"\n    config = copy.deepcopy(self.__dict__)\n    if 'self' in config:\n      config.pop('self')\n    if '__class__' in config:\n      config.pop('__class__')\n    if 'feature_configs' in config and config['feature_configs'] is not None:\n      config['feature_configs'] = [\n          keras.utils.legacy.serialize_keras_object(feature_config)\n          for feature_config in config['feature_configs']\n      ]\n    if 'regularizer_configs' in config and config[\n        'regularizer_configs'] is not None:\n      config['regularizer_configs'] = [\n          keras.utils.legacy.serialize_keras_object(regularizer_config)\n          for regularizer_config in config['regularizer_configs']\n      ]\n    if ('reflects_trust_in' in config and\n        config['reflects_trust_in'] is not None):\n      config['reflects_trust_in'] = [\n          keras.utils.legacy.serialize_keras_object(trust_config)\n          for trust_config in config['reflects_trust_in']\n      ]\n    if 'dominates' in config and config['dominates'] is not None:\n      config['dominates'] = [\n          keras.utils.legacy.serialize_keras_object(dominance_config)\n          for dominance_config in config['dominates']\n      ]\n    return config\n\n  @classmethod\n  def deserialize_nested_configs(cls, config, custom_objects=None):\n    \"\"\"Returns a deserialized configuration dictionary.\"\"\"\n    config = copy.deepcopy(config)\n    if 'feature_configs' in config and config['feature_configs'] is not None:\n      config['feature_configs'] = [\n          keras.utils.legacy.deserialize_keras_object(\n              feature_config, custom_objects=custom_objects\n          )\n          for feature_config in config['feature_configs']\n      ]\n    if 'regularizer_configs' in config and config[\n        'regularizer_configs'] is not None:\n      config['regularizer_configs'] = [\n          keras.utils.legacy.deserialize_keras_object(\n              regularizer_config, custom_objects=custom_objects\n          )\n          for regularizer_config in config['regularizer_configs']\n      ]\n    if ('reflects_trust_in' in config and\n        config['reflects_trust_in'] is not None):\n      config['reflects_trust_in'] = [\n          keras.utils.legacy.deserialize_keras_object(\n              trust_config, custom_objects=custom_objects\n          )\n          for trust_config in config['reflects_trust_in']\n      ]\n    if 'dominates' in config and config['dominates'] is not None:\n      config['dominates'] = [\n       
   keras.utils.legacy.deserialize_keras_object(\n              dominance_config, custom_objects=custom_objects\n          )\n          for dominance_config in config['dominates']\n      ]\n    return config\n\n\nclass _HasFeatureConfigs(object):\n  \"\"\"Base class for configs with `feature_configs` attribute.\"\"\"\n\n  def feature_config_by_name(self, feature_name):\n    \"\"\"Returns existing or default FeatureConfig with the given name.\"\"\"\n    if self.feature_configs is None:\n      self.feature_configs = []\n    for feature_config in self.feature_configs:\n      if feature_config.name == feature_name:\n        return feature_config\n    feature_config = FeatureConfig(feature_name)\n    self.feature_configs.append(feature_config)\n    return feature_config\n\n\nclass _HasRegularizerConfigs(object):\n  \"\"\"Base class for configs with `regularizer_configs` attribute.\"\"\"\n\n  def regularizer_config_by_name(self, regularizer_name):\n    \"\"\"Returns existing or default RegularizerConfig with the given name.\"\"\"\n    if self.regularizer_configs is None:\n      self.regularizer_configs = []\n    for regularizer_config in self.regularizer_configs:\n      if regularizer_config.name == regularizer_name:\n        return regularizer_config\n    regularizer_config = RegularizerConfig(regularizer_name)\n    self.regularizer_configs.append(regularizer_config)\n    return regularizer_config\n\n\n# pylint: disable=unused-argument\n\n\nclass CalibratedLatticeEnsembleConfig(_Config, _HasFeatureConfigs,\n                                      _HasRegularizerConfigs):\n  \"\"\"Config for calibrated lattice ensemble model.\n\n  A calibrated lattice ensemble model applies piecewise-linear and categorical\n  calibration on the input feature, followed by an ensemble of lattice models\n  and an optional output piecewise-linear calibration.\n\n  The ensemble structure can be one of the following and set via the `lattices`\n  flag:\n\n    - Explicit list of lists of features specifying features used in each\n      submodel.\n    - A random arrangement (also called Random Tiny Lattices, or RTL).\n    - Crystals growing algorithm: This algorithm first constructs a prefitting\n      model to assess pairwise interactions between features, and then uses\n      those estimates to construct a final model that puts interacting\n      features in the same lattice. 
For details see \"Fast and flexible monotonic\n      functions with ensembles of lattices\", Advances in Neural Information\n      Processing Systems, 2016.\n\n  Examples:\n\n  Creating a random ensemble (RTL) model:\n\n  ```python\n  model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n      num_lattices=6,  # number of lattices\n      lattice_rank=5,  # number of features in each lattice\n      feature_configs=[...],\n  )\n  feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)\n  train_input_fn = create_input_fn(num_epochs=100, ...)\n  estimator = tfl.estimators.CannedClassifier(\n      feature_columns=feature_columns,\n      model_config=model_config,\n      feature_analysis_input_fn=feature_analysis_input_fn)\n  estimator.train(input_fn=train_input_fn)\n  ```\n  You can also construct a random ensemble (RTL) using a `tfl.layers.RTL`\n  layer so long as all features have the same lattice size:\n  ```python\n  model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n      lattices='rtl_layer',\n      num_lattices=6,  # number of lattices\n      lattice_rank=5,  # number of features in each lattice\n      feature_configs=[...],\n  )\n  feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)\n  train_input_fn = create_input_fn(num_epochs=100, ...)\n  estimator = tfl.estimators.CannedClassifier(\n      feature_columns=feature_columns,\n      model_config=model_config,\n      feature_analysis_input_fn=feature_analysis_input_fn)\n  estimator.train(input_fn=train_input_fn)\n  ```\n  To create a Crystals model, you will need to provide a *prefitting_input_fn*\n  to the estimator constructor. This input_fn is used to train the prefitting\n  model, as described above. The prefitting model does not need to be fully\n  trained, so a few epochs should be enough.\n\n  ```python\n  model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n      lattices='crystals',  # feature arrangement method\n      num_lattices=6,  # number of lattices\n      lattice_rank=5,  # number of features in each lattice\n      feature_configs=[...],\n  )\n  feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)\n  prefitting_input_fn = create_input_fn(num_epochs=5, ...)\n  train_input_fn = create_input_fn(num_epochs=100, ...)\n  estimator = tfl.estimators.CannedClassifier(\n      feature_columns=feature_columns,\n      model_config=model_config,\n      feature_analysis_input_fn=feature_analysis_input_fn,\n      prefitting_input_fn=prefitting_input_fn)\n  estimator.train(input_fn=train_input_fn)\n  ```\n  \"\"\"\n  _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.\n\n  def __init__(self,\n               feature_configs=None,\n               lattices='random',\n               num_lattices=None,\n               lattice_rank=None,\n               interpolation='hypercube',\n               parameterization='all_vertices',\n               num_terms=2,\n               separate_calibrators=True,\n               use_linear_combination=False,\n               use_bias=False,\n               regularizer_configs=None,\n               output_min=None,\n               output_max=None,\n               output_calibration=False,\n               output_calibration_num_keypoints=10,\n               output_initialization='quantiles',\n               output_calibration_input_keypoints_type='fixed',\n               fix_ensemble_for_2d_constraints=True,\n               random_seed=0):\n    # pyformat: disable\n    \"\"\"Initializes a `CalibratedLatticeEnsembleConfig` instance.\n\n    Args:\n      
feature_configs: A list of `tfl.configs.FeatureConfig` instances that\n        specify configurations for each feature. If a configuration is not\n        provided for a feature, a default configuration will be used.\n      lattices: Should be one of the following:\n        - String `'random'` indicating that the features in each lattice should\n          be selected randomly\n        - String `'rtl_layer'` indicating that the features in each lattice\n          should be selected randomly using a `tfl.layers.RTL` layer. Note that\n          using a `tfl.layers.RTL` layer scales better than using separate\n          `tfl.layers.Lattice` instances for the ensemble.\n        - String `'crystals'` to use a heuristic to construct the lattice\n          ensemble based on pairwise feature interactions\n        - An explicit list of list of feature names to be used in each lattice\n          in the ensemble.\n      num_lattices: Number of lattices in the ensemble. Must be provided if\n        lattices are not explicitly provided.\n      lattice_rank: Number of features in each lattice. Must be provided if\n        lattices are not explicitly provided.\n      interpolation: One of 'hypercube' or 'simplex' interpolation. For a\n        d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas\n        'simplex' uses d+1 parameters and thus scales better. For details see\n        `tfl.lattice_lib.evaluate_with_simplex_interpolation` and\n        `tfl.lattice_lib.evaluate_with_hypercube_interpolation`.\n      parameterization: The parameterization of the lattice function class to\n        use. A lattice function is uniquely determined by specifying its value\n        on every lattice vertex. A parameterization scheme is a mapping from a\n        vector of parameters to a multidimensional array of lattice vertex\n        values. It can be one of:\n          - String `'all_vertices'`: This is the \"traditional\" parameterization\n            that keeps one scalar parameter per lattice vertex where the mapping\n            is essentially the identity map. With this scheme, the number of\n            parameters scales exponentially with the number of inputs to the\n            lattice. The underlying lattices used will be `tfl.layers.Lattice`\n            layers.\n          - String `'kronecker_factored'`: With this parameterization, for each\n            lattice input i we keep a collection of `num_terms` vectors each\n            having `feature_configs[0].lattice_size` entries (note that all\n            features must have the same lattice size). To obtain the tensor of\n            lattice vertex values, for `t=1,2,...,num_terms` we compute the\n            outer product of the `t'th` vector in each collection, multiply by a\n            per-term scale, and sum the resulting tensors. Finally, we add a\n            single shared bias parameter to each entry in the sum. With this\n            scheme, the number of parameters grows linearly with `lattice_rank`\n            (assuming lattice sizes and `num_terms` are held constant).\n            Currently, only monotonicity shape constraint and bound constraint\n            are supported for this scheme. Regularization is not currently\n            supported. The underlying lattices used will be\n            `tfl.layers.KroneckerFactoredLattice` layers.\n      num_terms: The number of terms in a lattice using `'kronecker_factored'`\n        parameterization. 
Ignored if parameterization is set to\n        `'all_vertices'`.\n      separate_calibrators: If features should be separately calibrated for each\n        lattice in the ensemble.\n      use_linear_combination: If set to true, a linear combination layer will be\n        used to combine ensemble outputs. Otherwise an averaging layer will be\n        used. If output is bounded or output calibration is used, then this\n        layer will be a weighted average.\n      use_bias: If a bias term should be used for the linear combination.\n      regularizer_configs: A list of `tfl.configs.RegularizerConfig` instances\n        that apply global regularization.\n      output_min: Lower bound constraint on the output of the model.\n      output_max: Upper bound constraint on the output of the model.\n      output_calibration: If a piecewise-linear calibration should be used on\n        the output of the lattice.\n      output_calibration_num_keypoints: Number of keypoints to use for the\n        output piecewise-linear calibration.\n      output_initialization: The initial values to setup for the output of the\n        model. When using output calibration, these values are used to\n        initialize the output keypoints of the output piecewise-linear\n        calibration. Otherwise the lattice parameters will be setup to form a\n        linear function in the range of output_initialization. It can be one of:\n          - String `'quantiles'`: Output is initliazed to label quantiles, if\n            possible.\n          - String `'uniform'`: Output is initliazed uniformly in label range.\n          - A list of numbers: To be used for initialization of the output\n            lattice or output calibrator.\n      output_calibration_input_keypoints_type: One of \"fixed\" or\n        \"learned_interior\". If \"learned_interior\", keypoints are initialized to\n        the values in `pwl_calibration_input_keypoints` but then allowed to vary\n        during training, with the exception of the first and last keypoint\n        location which are fixed.\n      fix_ensemble_for_2d_constraints: A boolean indicating whether to add\n        missing features to some lattices to resolve potential 2d constraint\n        violations which require lattices from ensemble to either contain both\n        constrained features or none of them, e.g. trapezoid trust constraint\n        requires a lattice that has the \"conditional\" feature to include the\n        \"main\" feature. 
Note that this might increase the final lattice rank.\n      random_seed: Random seed to use for randomized lattices.\n    \"\"\"\n    # pyformat: enable\n    super(CalibratedLatticeEnsembleConfig, self).__init__(locals())\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    return CalibratedLatticeEnsembleConfig(**_Config.deserialize_nested_configs(\n        config, custom_objects=custom_objects))\n\n\nclass CalibratedLatticeConfig(_Config, _HasFeatureConfigs,\n                              _HasRegularizerConfigs):\n  \"\"\"Config for calibrated lattice model.\n\n  A calibrated lattice model applies piecewise-linear and categorical\n  calibration on the input feature, followed by a lattice model and an\n  optional output piecewise-linear calibration.\n\n  Example:\n\n  ```python\n  model_config = tfl.configs.CalibratedLatticeConfig(\n      feature_configs=[...],\n  )\n  feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)\n  train_input_fn = create_input_fn(num_epochs=100, ...)\n  estimator = tfl.estimators.CannedClassifier(\n      feature_columns=feature_columns,\n      model_config=model_config,\n      feature_analysis_input_fn=feature_analysis_input_fn)\n  estimator.train(input_fn=train_input_fn)\n  ```\n  \"\"\"\n  _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.\n\n  def __init__(self,\n               feature_configs=None,\n               interpolation='hypercube',\n               parameterization='all_vertices',\n               num_terms=2,\n               regularizer_configs=None,\n               output_min=None,\n               output_max=None,\n               output_calibration=False,\n               output_calibration_num_keypoints=10,\n               output_initialization='quantiles',\n               output_calibration_input_keypoints_type='fixed',\n               random_seed=0):\n    \"\"\"Initializes a `CalibratedLatticeConfig` instance.\n\n    Args:\n      feature_configs: A list of `tfl.configs.FeatureConfig` instances that\n        specify configurations for each feature. If a configuration is not\n        provided for a feature, a default configuration will be used.\n      interpolation: One of 'hypercube' or 'simplex' interpolation. For a\n        d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas\n        'simplex' uses d+1 parameters and thus scales better. For details see\n        `tfl.lattice_lib.evaluate_with_simplex_interpolation` and\n        `tfl.lattice_lib.evaluate_with_hypercube_interpolation`.\n      parameterization: The parameterization of the lattice function class to\n        use. A lattice function is uniquely determined by specifying its value\n        on every lattice vertex. A parameterization scheme is a mapping from a\n        vector of parameters to a multidimensional array of lattice vertex\n        values. It can be one of:\n          - String `'all_vertices'`: This is the \"traditional\" parameterization\n            that keeps one scalar parameter per lattice vertex where the mapping\n            is essentially the identity map. With this scheme, the number of\n            parameters scales exponentially with the number of inputs to the\n            lattice. 
The underlying lattice used will be a `tfl.layers.Lattice`\n            layer.\n          - String `'kronecker_factored'`: With this parameterization, for each\n            lattice input i we keep a collection of `num_terms` vectors each\n            having `feature_configs[0].lattice_size` entries (note that all\n            features must have the same lattice size). To obtain the tensor of\n            lattice vertex values, for `t=1,2,...,num_terms` we compute the\n            outer product of the `t'th` vector in each collection, multiply by a\n            per-term scale, and sum the resulting tensors. Finally, we add a\n            single shared bias parameter to each entry in the sum. With this\n            scheme, the number of parameters grows linearly with\n            `len(feature_configs)` (assuming lattice sizes and `num_terms` are\n            held constant). Currently, only monotonicity shape constraint and\n            bound constraint are supported for this scheme. Regularization is\n            not currently supported. The underlying lattice used will be a\n            `tfl.layers.KroneckerFactoredLattice` layer.\n      num_terms: The number of terms in a lattice using `'kronecker_factored'`\n        parameterization. Ignored if parameterization is set to\n        `'all_vertices'`.\n      regularizer_configs: A list of `tfl.configs.RegularizerConfig` instances\n        that apply global regularization.\n      output_min: Lower bound constraint on the output of the model.\n      output_max: Upper bound constraint on the output of the model.\n      output_calibration: If a piecewise-linear calibration should be used on\n        the output of the lattice.\n      output_calibration_num_keypoints: Number of keypoints to use for the\n        output piecewise-linear calibration.\n      output_initialization: The initial values to setup for the output of the\n        model. When using output calibration, these values are used to\n        initialize the output keypoints of the output piecewise-linear\n        calibration. Otherwise the lattice parameters will be setup to form a\n        linear function in the range of output_initialization. It can be one of:\n          - String `'quantiles'`: Output is initliazed to label quantiles, if\n            possible.\n          - String `'uniform'`: Output is initliazed uniformly in label range.\n          - A list of numbers: To be used for initialization of the output\n            lattice or output calibrator.\n      output_calibration_input_keypoints_type: One of \"fixed\" or\n        \"learned_interior\". If \"learned_interior\", keypoints are initialized to\n        the values in `pwl_calibration_input_keypoints` but then allowed to vary\n        during training, with the exception of the first and last keypoint\n        location which are fixed.\n      random_seed: Random seed to use for initialization of a lattice with\n        `'kronecker_factored'` parameterization. 
Ignored if parameterization is\n        set to `'all_vertices'`.\n    \"\"\"\n    super(CalibratedLatticeConfig, self).__init__(locals())\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    return CalibratedLatticeConfig(**_Config.deserialize_nested_configs(\n        config, custom_objects=custom_objects))\n\n\nclass CalibratedLinearConfig(_Config, _HasFeatureConfigs,\n                             _HasRegularizerConfigs):\n  \"\"\"Config for calibrated linear model.\n\n  A calibrated linear model applies piecewise-linear and categorical\n  calibration on the input feature, followed by a linear combination and an\n  optional output piecewise-linear calibration. When using output calibration\n  or when output bounds are specified, the linear layer will apply weighted\n  averaging on calibrated inputs.\n\n  Example:\n\n  ```python\n  model_config = tfl.configs.CalibratedLinearConfig(\n      feature_configs=[...],\n  )\n  feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)\n  train_input_fn = create_input_fn(num_epochs=100, ...)\n  estimator = tfl.estimators.CannedClassifier(\n      feature_columns=feature_columns,\n      model_config=model_config,\n      feature_analysis_input_fn=feature_analysis_input_fn)\n  estimator.train(input_fn=train_input_fn)\n  ```\n  \"\"\"\n  _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.\n\n  def __init__(self,\n               feature_configs=None,\n               regularizer_configs=None,\n               use_bias=True,\n               output_min=None,\n               output_max=None,\n               output_calibration=False,\n               output_calibration_num_keypoints=10,\n               output_initialization='quantiles',\n               output_calibration_input_keypoints_type='fixed'):\n    \"\"\"Initializes a `CalibratedLinearConfig` instance.\n\n    Args:\n      feature_configs: A list of `tfl.configs.FeatureConfig` instances that\n        specify configurations for each feature. If a configuration is not\n        provided for a feature, a default configuration will be used.\n      regularizer_configs: A list of `tfl.configs.RegularizerConfig` instances\n        that apply global regularization.\n      use_bias: If a bias term should be used for the linear combination.\n      output_min: Lower bound constraint on the output of the model.\n      output_max: Upper bound constraint on the output of the model.\n      output_calibration: If a piecewise-linear calibration should be used on\n        the output of the lattice.\n      output_calibration_num_keypoints: Number of keypoints to use for the\n        output piecewise-linear calibration.\n      output_initialization: The initial values to setup for the output of the\n        model. When using output calibration, these values are used to\n        initialize the output keypoints of the output piecewise-linear\n        calibration. Otherwise the lattice parameters will be setup to form a\n        linear function in the range of output_initialization. It can be one of:\n          - String `'quantiles'`: Output is initialized to label quantiles, if\n            possible.\n          - String `'uniform'`: Output is initialized uniformly in label range.\n          - A list of numbers: To be used for initialization of the output\n            lattice or output calibrator.\n      output_calibration_input_keypoints_type: One of \"fixed\" or\n        \"learned_interior\". 
If \"learned_interior\", keypoints are initialized to\n        the values in `pwl_calibration_input_keypoints` but then allowed to vary\n        during training, with the exception of the first and last keypoint\n        location which are fixed.\n    \"\"\"\n    super(CalibratedLinearConfig, self).__init__(locals())\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    return CalibratedLinearConfig(**_Config.deserialize_nested_configs(\n        config, custom_objects=custom_objects))\n\n\n# TODO: add option for different pre-aggregation model (linear/ensemble)\nclass AggregateFunctionConfig(_Config, _HasFeatureConfigs,\n                              _HasRegularizerConfigs):\n  \"\"\"Config for aggregate function learning model.\n\n  An aggregate function learning model applies piecewise-linear and categorical\n  calibration on the ragged input features, followed by an aggregation layer\n  that aggregates the calibrated inputs. Lastly a lattice model and an optional\n  output piecewise-linear calibration are applied.\n\n  Example:\n\n  ```python\n  model_config = tfl.configs.AggregateFunctionConfig(\n      feature_configs=[...],\n  )\n  model = tfl.premade.AggregateFunction(model_config)\n  model.compile(...)\n  model.fit(...)\n  model.evaluate(...)\n  ```\n  \"\"\"\n  _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.\n\n  def __init__(self,\n               feature_configs,\n               regularizer_configs=None,\n               middle_dimension=1,\n               middle_lattice_size=2,\n               middle_calibration=False,\n               middle_calibration_num_keypoints=10,\n               middle_calibration_input_keypoints_type='fixed',\n               middle_monotonicity=None,\n               middle_lattice_interpolation='hypercube',\n               aggregation_lattice_interpolation='hypercube',\n               output_min=None,\n               output_max=None,\n               output_calibration=False,\n               output_calibration_num_keypoints=10,\n               output_initialization='uniform',\n               output_calibration_input_keypoints_type='fixed'):\n    \"\"\"Initializes an `AggregateFunctionConfig` instance.\n\n    Args:\n      feature_configs: A list of `tfl.configs.FeatureConfig` instances that\n        specify configurations for each feature.\n      regularizer_configs: A list of `tfl.configs.RegularizerConfig` instances\n        that apply global regularization.\n      middle_dimension: The number of calibrated lattices that are applied to\n        each block. The outputs of these lattices are then averaged over the\n        blocks, and the middle_dimension resulting numbers are then passed into\n        the \"middle\" calibrated lattice. This middle lattice therefore has input\n        dimension equal to middle_dimension.\n      middle_lattice_size: Size of each of the middle_lattice dimensions.\n      middle_calibration: If a piecewise-linear calibration should be used on\n        the inputs to the middle lattice.\n      middle_calibration_num_keypoints: Number of keypoints to use for the\n        middle piecewise-linear calibration.\n      middle_calibration_input_keypoints_type: One of \"fixed\" or\n        \"learned_interior\". 
If \"learned_interior\", keypoints are initialized to\n        the values in `pwl_calibration_input_keypoints` but then allowed to vary\n        during training, with the exception of the first and last keypoint\n        location which are fixed.\n      middle_monotonicity: Specifies if the middle calibrators should be\n        monotonic, using 'increasing' or 1 to indicate increasing monotonicity,\n        'decreasing' or -1 to indicate decreasing monotonicity, and 'none' or 0\n        to indicate no monotonicity constraints.\n      middle_lattice_interpolation: One of 'hypercube' or 'simplex'. For a\n        d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas\n        'simplex' uses d+1 parameters and thus scales better. For details see\n        `tfl.lattice_lib.evaluate_with_simplex_interpolation` and\n        `tfl.lattice_lib.evaluate_with_hypercube_interpolation`.\n      aggregation_lattice_interpolation: One of 'hypercube' or 'simplex'. For a\n        d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas\n        'simplex' uses d+1 parameters and thus scales better. For details see\n        `tfl.lattice_lib.evaluate_with_simplex_interpolation` and\n        `tfl.lattice_lib.evaluate_with_hypercube_interpolation`.\n      output_min: Lower bound constraint on the output of the model.\n      output_max: Upper bound constraint on the output of the model.\n      output_calibration: If a piecewise-linear calibration should be used on\n        the output of the lattice.\n      output_calibration_num_keypoints: Number of keypoints to use for the\n        output piecewise-linear calibration.\n      output_initialization: The initial values to setup for the output of the\n        model. When using output calibration, these values are used to\n        initialize the output keypoints of the output piecewise-linear\n        calibration. Otherwise the lattice parameters will be setup to form a\n        linear function in the range of output_initialization. It can be one of:\n          - String `'uniform'`: Output is initliazed uniformly in label range.\n          - A list of numbers: To be used for initialization of the output\n            lattice or output calibrator.\n      output_calibration_input_keypoints_type: One of \"fixed\" or\n        \"learned_interior\". If \"learned_interior\", keypoints are initialized to\n        the values in `pwl_calibration_input_keypoints` but then allowed to vary\n        during training, with the exception of the first and last keypoint\n        location which are fixed.\n    \"\"\"\n    super(AggregateFunctionConfig, self).__init__(locals())\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    return AggregateFunctionConfig(**_Config.deserialize_nested_configs(\n        config, custom_objects=custom_objects))\n\n\nclass FeatureConfig(_Config, _HasRegularizerConfigs):\n  \"\"\"Per-feature configuration for TFL canned estimators.\n\n  A feature can either be numerical or categorical. Numeric features will be\n  calibrated using a piecewise-linear function with the given number of\n  keypoints. Categorical features should have `num_buckets > 0` and the\n  `vocabulary_list` represent their categories. Several of the config fields\n  can be filled in automatically based on the `FeatureColumns` used by the\n  model but can also be provided explicitly. 
See `__init__` args comments for\n  details.\n\n  Currently only one-dimensional features are supported.\n\n  Examples:\n\n  ```python\n  feature_columns = [\n      tf.feature_column.numeric_column(\n          'age', default_value=-1),\n      tf.feature_column.categorical_column_with_vocabulary_list(\n          'thal', vocabulary_list=['normal', 'fixed', 'reversible']),\n      ...\n  ]\n\n  model_config = tfl.configs.CalibratedLatticeConfig(\n      feature_configs=[\n          tfl.configs.FeatureConfig(\n              name='age',\n              lattice_size=3,\n              # Monotonically increasing.\n              monotonicity='increasing',\n              # Per feature regularization.\n              regularizer_configs=[\n                  tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n              ],\n          ),\n          tfl.configs.FeatureConfig(\n              name='thal',\n              # Partial monotonicity:\n              # output(normal) <= output(fixed)\n              # output(normal) <= output(reversible)\n              monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],\n          ),\n      ],\n      # Global regularizers\n      regularizer_configs=[...])\n  feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)\n  train_input_fn = create_input_fn(num_epochs=100, ...)\n  estimator = tfl.estimators.CannedClassifier(\n      feature_columns=feature_columns,\n      model_config=model_config,\n      feature_analysis_input_fn=feature_analysis_input_fn)\n  estimator.train(input_fn=train_input_fn)\n  ```\n  \"\"\"\n  _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.\n\n  def __init__(self,\n               name,\n               is_missing_name=None,\n               default_value=None,\n               lattice_size=2,\n               monotonicity='none',\n               unimodality='none',\n               reflects_trust_in=None,\n               dominates=None,\n               pwl_calibration_always_monotonic=False,\n               pwl_calibration_convexity=0,\n               pwl_calibration_num_keypoints=10,\n               pwl_calibration_input_keypoints='quantiles',\n               pwl_calibration_input_keypoints_type='fixed',\n               pwl_calibration_clip_min=None,\n               pwl_calibration_clip_max=None,\n               pwl_calibration_clamp_min=False,\n               pwl_calibration_clamp_max=False,\n               num_buckets=0,\n               vocabulary_list=None,\n               regularizer_configs=None):\n    \"\"\"Initializes a `FeatureConfig` instance.\n\n    Args:\n      name: The name of the feature, which should match the name of a given\n        FeatureColumn or a key in the input feature dict.\n      is_missing_name: The name of a FeatureColumn or key in the input feature\n        dict that indicates missing-ness of the main feature.\n      default_value: [Automatically filled in from `FeatureColumns`] If set,\n        this value in the input represents a missing value. For numeric features,\n        the output will be imputed. 
If default_value is provided for a\n        categorical feature, it corresponds to the last bucket counted in\n        num_buckets.\n      lattice_size: The number of lattice vertices to be used along the axis\n        for this feature.\n      monotonicity: - For numeric features, specifies if the model output should\n        be monotonic in this feature, using 'increasing' or 1 to indicate\n        increasing monotonicity, 'decreasing' or -1 to indicate decreasing\n        monotonicity, and 'none' or 0 to indicate no monotonicity constraints. -\n        For categorical features, a list of (category_a, category_b) pairs from\n        the vocabulary list indicating that with other features fixed, model\n        output for category_b should be greater than or equal to category_a. If\n        no vocabulary list is specified, we assume an implicit vocabulary in the\n        range `[0, num_buckets - 1]`.\n      unimodality: For numeric features specifies if the model output should be\n        unimodal in the corresponding feature, using 'valley' or 1 to indicate\n        that function first decreases then increases, using 'peak' or -1 to\n        indicate that function first increases then decreases, using 'none' or 0\n        to indicate no unimodality constraints. Not used for categorical\n        features.\n      reflects_trust_in: None or a list of `tfl.configs.TrustConfig` instances.\n      dominates: None or a list of `tfl.configs.DominanceConfig` instances.\n      pwl_calibration_always_monotonic: Specifies if the piecewise-linear\n        calibration should always be monotonic regardless of the specified\n        end-to-end model output `monotonicity` with respect to this feature.\n      pwl_calibration_convexity: Specifies the convexity constraints of the\n        calibrators for numeric features. Convexity is indicated by 'convex' or\n        1, concavity is indicated by 'concave' or -1, 'none' or 0 indicates no\n        convexity/concavity constraints. Does not affect categorical features.\n        Concavity together with increasing monotonicity as well as convexity\n        together with decreasing monotonicity results in diminishing return\n        constraints.\n      pwl_calibration_num_keypoints: Number of keypoints to use for\n        piecewise-linear calibration.\n      pwl_calibration_input_keypoints: Indicates what should be used for the\n        input keypoints of the piecewise-linear calibration. It can be one of:\n          - String `'quantiles'`: Input keypoints are set to feature quantiles.\n          - String `'uniform'`: Input keypoints are uniformly spaced in feature\n            range.\n          - A list of numbers: Explicitly specifies the keypoints.\n      pwl_calibration_input_keypoints_type: One of \"fixed\" or\n        \"learned_interior\". If \"learned_interior\", keypoints are initialized to\n        the values in `pwl_calibration_input_keypoints` but then allowed to vary\n        during training, with the exception of the first and last keypoint\n        location which are fixed. 
Convexity can only be imposed with \"fixed\".\n      pwl_calibration_clip_min: Input values are lower clipped by this value.\n      pwl_calibration_clip_max: Input values are upper clipped by this value.\n      pwl_calibration_clamp_min: for monotonic calibrators ensures that the\n        minimum value in calibration output is reached.\n      pwl_calibration_clamp_max: for monotonic calibrators ensures that the\n        maximum value in calibration output is reached.\n      num_buckets: [Automatically filled in from `FeatureColumns`] Number of\n        categories for a categorical feature. Out-of-vocabulary and\n        missing/default value should be counted into num_buckets (last buckets).\n      vocabulary_list: [Automatically filled in from `FeatureColumns`] The input\n        vocabulary of the feature.\n      regularizer_configs: None or a list of per-feature\n        `tfl.configs.RegularizerConfig` instances.\n    \"\"\"\n    super(FeatureConfig, self).__init__(locals())\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    return FeatureConfig(**_Config.deserialize_nested_configs(\n        config, custom_objects=custom_objects))\n\n\nclass RegularizerConfig(_Config):\n  \"\"\"Regularizer configuration for TFL canned estimators.\n\n  Regularizers can either be applied to specific features, or can be applied\n  globally to all features or lattices.\n\n\n  * **Calibrator regularizers:**\n\n    These regularizers are applied to PWL calibration layers.\n\n    - `'calib_laplacian'`: Creates an instance of\n      `tfl.pwl_calibration_layer.LaplacianRegularizer`. A calibrator laplacian\n      regularizer penalizes the changes in the output and results in a *flatter\n      calibration function*.\n    - `'calib_hessian'`: Creates an instance of\n      `tfl.pwl_calibration_layer.HessianRegularizer`. A calibrator hessian\n      regularizer penalizes changes in the slope, resulting in a *more linear\n      calibration*.\n    - `'calib_wrinkle'`: Creates an instance of\n      `tfl.pwl_calibration_layer.WrinkleRegularizer`. A calibrator wrinkle\n      regularizer penalizes the second derivative, resulting in a smoother\n      function with *less changes in the curvature*.\n\n\n  * **Lattice regularizers:**\n\n    These regularizers are applied to lattice layers.\n\n    - `'laplacian'`: Creates an instance of\n      `tfl.lattice_layer.LaplacianRegularizer`. Laplacian regularizers penalize\n      the difference between adjacent vertices in multi-cell lattice, resulting\n      in a *flatter lattice function*.\n    - `'torsion'`: Creates an instance of\n      `tfl.lattice_layer.TorsionRegularizer`. Torsion regularizers penalizes\n      how much the lattice function twists from side-to-side, a non-linear\n      interactions in each 2 x 2 cell. 
Using this regularization results in a\n      *more linear lattice function*.\n\n\n  Examples:\n\n  ```python\n  model_config = tfl.configs.CalibratedLatticeConfig(\n      feature_configs=[\n          tfl.configs.FeatureConfig(\n              name='age',\n              lattice_size=3,\n              # Per feature regularization.\n              regularizer_configs=[\n                  tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n              ],\n          ),\n          tfl.configs.FeatureConfig(\n              name='thal',\n              # Partial monotonicity:\n              # output(normal) <= output(fixed)\n              # output(normal) <= output(reversible)\n              monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],\n          ),\n      ],\n      # Global regularizers\n      regularizer_configs=[\n          # Torsion regularizer applied to the lattice to make it more linear.\n          configs.RegularizerConfig(name='torsion', l2=1e-4),\n          # Globally defined calibration regularizer is applied to all features.\n          configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n      ])\n  feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)\n  train_input_fn = create_input_fn(num_epochs=100, ...)\n  estimator = tfl.estimators.CannedClassifier(\n      feature_columns=feature_columns,\n      model_config=model_config,\n      feature_analysis_input_fn=feature_analysis_input_fn)\n  estimator.train(input_fn=train_input_fn)\n  ```\n  \"\"\"\n  _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.\n\n  def __init__(self, name, l1=0.0, l2=0.0):\n    \"\"\"Initializes a `RegularizerConfig` instance.\n\n    Args:\n      name: The name of the regularizer.\n      l1: l1 regularization amount.\n      l2: l2 regularization amount.\n    \"\"\"\n    super(RegularizerConfig, self).__init__(locals())\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    return RegularizerConfig(**_Config.deserialize_nested_configs(\n        config, custom_objects=custom_objects))\n\n\nclass TrustConfig(_Config):\n  \"\"\"Configuration for feature trusts in TFL canned estimators.\n\n  You can specify how a feature reflects trust in another feature. Supported\n  trust types (see `tfl.layers.Lattice` for details):\n\n  - `'edgeworth'`: Edgeworth trust constrains the function to be more\n      responsive to a main feature as a secondary conditional feature increases\n      or decreases. For example, we may want the model to rely more on average\n      rating (main feature) when the number of reviews (conditional feature) is\n      high. In particular, the constraint guarantees that a given change in the\n      main feature's value will change the model output by more when a secondary\n      feature indicates higher trust in the main feature. Note that the\n      constraint only works when the model is monotonic in the main feature.\n  - `'trapezoid'`: Trapezoid trust is conceptually similar to edgeworth trust,\n      but this constraint guarantees that the range of possible outputs along\n      the main feature dimension, when a conditional feature indicates low\n      trust, is a *subset* of the range of outputs when a conditional feature\n      indicates high trust. When lattices have 2 vertices in each constrained\n      dimension, this implies edgeworth trust (which only constrains the size of\n      the relevant ranges). 
With more than 2 lattice vertices per dimension, the\n      two constraints diverge and are not necessarily 'weaker' or 'stronger'\n      than each other - edgeworth trust acts throughout the lattice interior on\n      delta shifts in the main feature, while trapezoid trust only acts on the\n      min and max extremes of the main feature, constraining the overall range\n      of outputs across the domain of the main feature. The two types of trust\n      constraints can be applied jointly.\n\n  Trust constraints only affect lattices. When using trapezoid constraints in\n  ensemble models, note that if a conditional feature is used in a lattice\n  without the main feature also being used in the same lattice, then the\n  trapezoid constraint might be violated for the ensemble function.\n\n  Examples:\n\n  One feature reflecting trust in another:\n\n  ```python\n  model_config = tfl.configs.CalibratedLatticeConfig(\n      feature_configs=[\n          tfl.configs.FeatureConfig(\n              name='num_reviews',\n              reflects_trust_in=[\n                  configs.TrustConfig(\n                      feature_name='average_rating', trust_type='edgeworth'),\n              ],\n          ),\n          tfl.configs.FeatureConfig(\n              name='average_rating',\n          ),\n      ])\n  ```\n\n  Features can reflect positive or negative trust in other features. For\n  example, if the task is to estimate a property price in a neighborhood given\n  two average prices for commercial and residential properties, you can use a\n  trust feature `percentage_commercial_properties` to indicate that the model\n  should be more responsive to the commercial estimate if more properties in\n  the neighborhood are commercial. You can simultaneously have a negative trust\n  constraint for residential properties, since higher commercial land usage\n  indicates fewer houses, hence less market influence and a less accurate\n  estimate for residential property prices.\n\n  ```python\n  model_config = tfl.configs.CalibratedLatticeConfig(\n      feature_configs=[\n          tfl.configs.FeatureConfig(\n              name='percentage_commercial_properties',\n              reflects_trust_in=[\n                  configs.TrustConfig(\n                      feature_name='average_commercial_property_price',\n                      direction='positive'),\n                  configs.TrustConfig(\n                      feature_name='average_residential_property_price',\n                      direction='negative'),\n              ],\n          ),\n          tfl.configs.FeatureConfig(\n              name='average_commercial_property_price',\n          ),\n          tfl.configs.FeatureConfig(\n              name='average_residential_property_price',\n          ),\n          tfl.configs.FeatureConfig(\n              name='square_footage',\n          ),\n          ...\n      ])\n  ```\n  \"\"\"\n  _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.\n\n  def __init__(self,\n               feature_name,\n               trust_type='edgeworth',\n               direction='positive'):\n    \"\"\"Initializes a `TrustConfig` instance.\n\n    Args:\n      feature_name: Name of the \"main\" feature for the trust constraint.\n      trust_type: Type of trust constraint. Either `'edgeworth'` or\n        `'trapezoid'`.\n      direction: Direction of the trust. 
Should be: `'positive'`, `'negative'`,\n        1 or -1.\n    \"\"\"\n    super(TrustConfig, self).__init__(locals())\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    return TrustConfig(**_Config.deserialize_nested_configs(\n        config, custom_objects=custom_objects))\n\n\nclass DominanceConfig(_Config):\n  \"\"\"Configuration for dominance constraints in TFL canned estimators.\n\n  You can specify how a feature dominates another feature. Supported dominance\n  types (see `tfl.layers.Lattice` and `tfl.layers.Linear` for details):\n\n  - `'monotonic'`: Monotonic dominance constrains the function to require the\n      effect (slope) in the direction of the *dominant* dimension to be greater\n      than that of the *weak* dimension for any point in both lattice and linear\n      models. Both dominant and weak dimensions must be monotonic. The\n      constraint is guaranteed to be satisfied at the end of training for\n      linear models, but might not be strictly satisfied for lattice models. In\n      such cases, increase the number of projection iterations.\n\n  Example:\n\n  ```python\n  model_config = tfl.configs.CalibratedLatticeConfig(\n      feature_configs=[\n          tfl.configs.FeatureConfig(\n              name='num_purchases',\n              dominates=[\n                  configs.DominanceConfig(\n                      feature_name='num_clicks', dominance_type='monotonic'),\n              ],\n          ),\n          tfl.configs.FeatureConfig(\n              name='num_clicks',\n          ),\n      ])\n  ```\n  \"\"\"\n  _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.\n\n  def __init__(self, feature_name, dominance_type='monotonic'):\n    \"\"\"Initializes a `DominanceConfig` instance.\n\n    Args:\n      feature_name: Name of the `\"dominant\"` feature for the dominance\n        constraint.\n      dominance_type: Type of dominance constraint. Currently, supports\n        `'monotonic'`.\n    \"\"\"\n    super(DominanceConfig, self).__init__(locals())\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    return DominanceConfig(**_Config.deserialize_nested_configs(\n        config, custom_objects=custom_objects))\n\n\nclass _TypeDict(collections.defaultdict):\n  \"\"\"Type dict that defaults to string type for hparams.\"\"\"\n\n  def __init__(self, hparams):\n    super(_TypeDict,\n          self).__init__(lambda: str,\n                         {k: type(v) for k, v in hparams.values().items()})\n\n  def __contains__(self, _):\n    return True\n\n\ndef apply_updates(model_config, updates):\n  \"\"\"Updates a model config with the given set of (key, values) updates.\n\n  Any value passed in the updates that matches a field of the config will be\n  applied to the config. Nested configs can be updated as follows: to add/update\n  a field `FIELD` in feature config for feature `FEATURE`, use\n  `feature__FEATURE__FIELD` as the key. To add/update a field `FIELD` for\n  regularizer with name `REGULARIZER` use `regularizer__REGULARIZER__FIELD` as\n  the key. This naming scheme can be nested. 
When possible, string values will\n  be converted to the corresponding value type in the model config.\n\n  Example:\n\n  ```python\n  model_config = ...\n  updates = [\n      ('output_max', 1),\n      ('regularizer__torsion__l1', 0.001),\n      ('feature__some_feature_name__lattice_size', 4),\n      ('feature__some_feature_name__regularizer__calib_hessian__l2', 0.001),\n      ('unrelated_haparam_not_affecting_model_config', 42),\n  ]\n  configs.apply_updates(model_config, updates)\n  ```\n\n  Arguments:\n    model_config: The model config object to apply the updates to.\n    updates: A list of (key, value) pairs with potential config updates. Values\n      that are not matched to a field in the model config will be ignored.\n\n  Returns:\n    Number of updates that are applied to the model config.\n  \"\"\"\n  applied_updates = 0\n  for k, v in updates:\n    if _apply_update(model_config, k, v):\n      applied_updates += 1\n      logging.info('Updated model config with %s=%s', k, str(v))\n  return applied_updates\n\n\ndef _apply_update(node, k, v):\n  \"\"\"Applies k, v updates to the given config node. See apply_updates.\"\"\"\n  while '__' in k:\n    parts = k.split('__', 2)\n    if len(parts) != 3:\n      return False\n    prefix, child_node_name, k = parts\n    if (prefix == _HPARAM_FEATURE_PREFIX and\n        isinstance(node, _HasFeatureConfigs)):\n      node = node.feature_config_by_name(child_node_name)\n    elif (prefix == _HPARAM_REGULARIZER_PREFIX and\n          isinstance(node, _HasRegularizerConfigs)):\n      node = node.regularizer_config_by_name(child_node_name)\n    else:\n      return False\n\n  if hasattr(node, k):\n    if isinstance(v, str):\n      current_value = getattr(node, k)\n      if current_value is None:\n        raise ValueError(\n            'Field `{}` has None value and can not be overridden by the '\n            'hparams string value `{}` since the type cannot be inferred. An '\n            'initial value must be set for the field to use string hparams.'\n            .format(k, v))\n      v = type(current_value)(v)\n\n    setattr(node, k, v)\n    return True\n\n  return False\n"
  },
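The per-feature options documented above (monotonicity, unimodality, calibrator convexity and keypoints) are typically combined inside a model config. Below is a minimal sketch, assuming the `tfl.configs` API shown in `configs.py` above; the feature names are made up for illustration only.

```python
import tensorflow_lattice as tfl

# Hypothetical features, mirroring the options documented in FeatureConfig.
feature_configs = [
    tfl.configs.FeatureConfig(
        name='age',
        monotonicity='increasing',                  # output never decreases in age
        pwl_calibration_num_keypoints=10,
        pwl_calibration_input_keypoints='quantiles',
    ),
    tfl.configs.FeatureConfig(
        name='distance_to_store',
        unimodality='valley',                       # first decreases, then increases
        pwl_calibration_convexity='convex',
    ),
    tfl.configs.FeatureConfig(
        name='thal',
        num_buckets=3,
        vocabulary_list=['normal', 'fixed', 'reversible'],
        monotonicity=[('normal', 'fixed')],         # partial order over categories
    ),
]
model_config = tfl.configs.CalibratedLatticeConfig(feature_configs=feature_configs)
```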
  {
    "path": "tensorflow_lattice/python/configs_test.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for TFL model configuration library.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow_lattice.python import categorical_calibration_layer\nfrom tensorflow_lattice.python import configs\nfrom tensorflow_lattice.python import lattice_layer\nfrom tensorflow_lattice.python import linear_layer\nfrom tensorflow_lattice.python import premade\nfrom tensorflow_lattice.python import pwl_calibration_layer\n\ntfl_custom_objects = {\n    'CalibratedLatticeEnsemble':\n        premade.CalibratedLatticeEnsemble,\n    'CalibratedLattice':\n        premade.CalibratedLattice,\n    'CalibratedLinear':\n        premade.CalibratedLinear,\n    'CategoricalCalibration':\n        categorical_calibration_layer.CategoricalCalibration,\n    'FeatureConfig':\n        configs.FeatureConfig,\n    'RegularizerConfig':\n        configs.RegularizerConfig,\n    'TrustConfig':\n        configs.TrustConfig,\n    'DominanceConfig':\n        configs.DominanceConfig,\n    'CalibratedLatticeEnsembleConfig':\n        configs.CalibratedLatticeEnsembleConfig,\n    'CalibratedLatticeConfig':\n        configs.CalibratedLatticeConfig,\n    'CalibratedLinearConfig':\n        configs.CalibratedLinearConfig,\n    'Lattice':\n        lattice_layer.Lattice,\n    'Linear':\n        linear_layer.Linear,\n    'PWLCalibration':\n        pwl_calibration_layer.PWLCalibration,\n}\n\n\nclass ConfigsTest(tf.test.TestCase):\n\n  def test_from_config(self):\n    feature_configs = [\n        configs.FeatureConfig(\n            name='feature_a',\n            pwl_calibration_input_keypoints='quantiles',\n            pwl_calibration_num_keypoints=8,\n            monotonicity=1,\n            pwl_calibration_clip_max=100,\n        ),\n        configs.FeatureConfig(\n            name='feature_b',\n            lattice_size=3,\n            unimodality='valley',\n            pwl_calibration_input_keypoints='uniform',\n            pwl_calibration_num_keypoints=5,\n            pwl_calibration_clip_min=130,\n            pwl_calibration_convexity='convex',\n            regularizer_configs=[\n                configs.RegularizerConfig(name='calib_hesian', l2=3e-3),\n            ],\n        ),\n        configs.FeatureConfig(\n            name='feature_c',\n            pwl_calibration_input_keypoints=[0.0, 0.5, 1.0],\n            reflects_trust_in=[\n                configs.TrustConfig(feature_name='feature_a'),\n                configs.TrustConfig(feature_name='feature_b', direction=-1),\n            ],\n            dominates=[\n                configs.DominanceConfig(\n                    feature_name='feature_d', dominance_type='monotonic'),\n            ],\n        ),\n        configs.FeatureConfig(\n            name='feature_d',\n            num_buckets=3,\n            vocabulary_list=['a', 'b', 'c'],\n            default_value=-1,\n        
),\n    ]\n    # First we test CalibratedLatticeEnsembleConfig\n    model_config = configs.CalibratedLatticeEnsembleConfig(\n        feature_configs=feature_configs,\n        lattices=[['feature_a', 'feature_b'], ['feature_c', 'feature_d']],\n        separate_calibrators=True,\n        regularizer_configs=[\n            configs.RegularizerConfig('torsion', l2=1e-4),\n        ],\n        output_min=0.0,\n        output_max=1.0,\n        output_calibration=True,\n        output_calibration_num_keypoints=5,\n        output_initialization=[0.0, 1.0])\n    model_config_copy = configs.CalibratedLatticeEnsembleConfig.from_config(\n        model_config.get_config(), tfl_custom_objects)\n    self.assertDictEqual(model_config.get_config(),\n                         model_config_copy.get_config())\n    # Next we test CalibratedLatticeConfig\n    model_config = configs.CalibratedLatticeConfig(\n        feature_configs=feature_configs,\n        regularizer_configs=[\n            configs.RegularizerConfig('torsion', l2=1e-4),\n        ],\n        output_min=0.0,\n        output_max=1.0,\n        output_calibration=True,\n        output_calibration_num_keypoints=8,\n        output_initialization='quantiles')\n    model_config_copy = configs.CalibratedLatticeConfig.from_config(\n        model_config.get_config(), tfl_custom_objects)\n    self.assertDictEqual(model_config.get_config(),\n                         model_config_copy.get_config())\n    # Last we test CalibratedLinearConfig\n    model_config = configs.CalibratedLinearConfig(\n        feature_configs=feature_configs,\n        regularizer_configs=[\n            configs.RegularizerConfig('calib_hessian', l2=1e-4),\n        ],\n        use_bias=True,\n        output_min=0.0,\n        output_max=None,\n        output_calibration=True,\n        output_initialization='uniform')\n    model_config_copy = configs.CalibratedLinearConfig.from_config(\n        model_config.get_config(), tfl_custom_objects)\n    self.assertDictEqual(model_config.get_config(),\n                         model_config_copy.get_config())\n\n  def test_updates(self):\n    model_config = configs.CalibratedLatticeConfig(\n        output_min=0,\n        regularizer_configs=[\n            configs.RegularizerConfig(name='torsion', l2=2e-3),\n        ],\n        feature_configs=[\n            configs.FeatureConfig(\n                name='feature_a',\n                pwl_calibration_input_keypoints='quantiles',\n                pwl_calibration_num_keypoints=8,\n                monotonicity=1,\n                pwl_calibration_clip_max=100,\n            ),\n            configs.FeatureConfig(\n                name='feature_b',\n                lattice_size=3,\n                unimodality='valley',\n                pwl_calibration_input_keypoints='uniform',\n                pwl_calibration_num_keypoints=5,\n                pwl_calibration_clip_min=130,\n                pwl_calibration_convexity='convex',\n                regularizer_configs=[\n                    configs.RegularizerConfig(name='calib_hessian', l2=3e-3),\n                ],\n            ),\n            configs.FeatureConfig(\n                name='feature_c',\n                pwl_calibration_input_keypoints=[0.0, 0.5, 1.0],\n                reflects_trust_in=[\n                    configs.TrustConfig(feature_name='feature_a'),\n                    configs.TrustConfig(feature_name='feature_b', direction=-1),\n                ],\n            ),\n            configs.FeatureConfig(\n                name='feature_d',\n              
  num_buckets=3,\n                vocabulary_list=['a', 'b', 'c'],\n                default_value=-1,\n            ),\n        ])\n\n    updates = [\n        # Update values can be passed in as numbers.\n        ('output_max', 1.0),  # update\n        ('regularizer__torsion__l2', 0.004),  # update\n        ('regularizer__calib_hessian__l1', 0.005),  # insert\n        ('feature__feature_a__lattice_size', 3),  # update\n        ('feature__feature_e__lattice_size', 4),  # insert\n        # Update values can be strings.\n        ('unrelated_hparams_not_affecting_config', 'unrelated'),\n        ('feature__feature_a__regularizer__calib_wrinkle__l1', '0.6'),  # insert\n        ('feature__feature_b__regularizer__calib_hessian__l1', '0.7'),  # update\n        ('yet__another__unrelated_config', '4'),\n    ]\n    self.assertEqual(configs.apply_updates(model_config, updates), 7)\n\n    model_config.feature_config_by_name('feature_a').monotonicity = 'none'\n    model_config.feature_config_by_name('feature_f').num_buckets = 4  # insert\n\n    feature_names = [\n        feature_config.name for feature_config in model_config.feature_configs\n    ]\n    expected_feature_names = [\n        'feature_a', 'feature_b', 'feature_c', 'feature_d', 'feature_e',\n        'feature_f'\n    ]\n    self.assertCountEqual(feature_names, expected_feature_names)\n\n    global_regularizer_names = [\n        regularizer_config.name\n        for regularizer_config in model_config.regularizer_configs\n    ]\n    expected_global_regularizer_names = ['torsion', 'calib_hessian']\n    self.assertCountEqual(global_regularizer_names,\n                          expected_global_regularizer_names)\n\n    self.assertEqual(model_config.output_max, 1.0)\n    self.assertEqual(\n        model_config.feature_config_by_name('feature_a').lattice_size, 3)\n    self.assertEqual(\n        model_config.feature_config_by_name(\n            'feature_b').pwl_calibration_convexity, 'convex')\n    self.assertEqual(\n        model_config.feature_config_by_name('feature_e').lattice_size, 4)\n    self.assertEqual(\n        model_config.regularizer_config_by_name('torsion').l2, 0.004)\n    self.assertEqual(\n        model_config.regularizer_config_by_name('calib_hessian').l1, 0.005)\n    self.assertEqual(\n        model_config.feature_config_by_name(\n            'feature_a').regularizer_config_by_name('calib_wrinkle').l1, 0.6)\n    self.assertEqual(\n        model_config.feature_config_by_name(\n            'feature_b').regularizer_config_by_name('calib_hessian').l1, 0.7)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_lattice/python/internal_utils.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Internal helpers shared by multiple modules in TFL.\n\nNote that this module is not expected to be used by TFL users, and that it is\nnot exposed in the TFL package.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport tensorflow as tf\n\n\ndef _topological_sort(key_less_than_values):\n  \"\"\"Topological sort for monotonicities.\n\n  Args:\n    key_less_than_values: A defaultdict from index to a list of indices, such\n      that for j in key_less_than_values[i] we must have output(i) <= output(j).\n\n  Returns:\n    A topologically sorted list of indices.\n\n  Raises:\n    ValueError: If monotonicities are circular.\n  \"\"\"\n  all_values = set()\n  for values in key_less_than_values.values():\n    all_values.update(values)\n\n  q = [k for k in key_less_than_values if k not in all_values]\n  if not q:\n    raise ValueError(\n        \"Circular monotonicity constraints: {}\".format(key_less_than_values))\n\n  result = []\n  seen = set()\n  while q:\n    v = q[-1]\n    seen.add(v)\n    expand = [x for x in key_less_than_values[v] if x not in seen]\n    if not expand:\n      result = [v] + result\n      q.pop()\n    else:\n      q.append(expand[0])\n\n  return result\n\n\ndef _min_projection(weights, sorted_indices, key_less_than_values, step):\n  \"\"\"Returns an approximate partial min projection with the given step_size.\n\n  Args:\n    weights: A list of tensors of shape `(units,)` to be approximatly projected\n      based on the monotonicity constraints.\n    sorted_indices: Topologically sorted list of indices based on the\n      monotonicity constraints.\n    key_less_than_values: A defaultdict from index to a list of indices, such\n      that for `j` in `key_less_than_values[i]` we must have `weight[i] <=\n      weight[j]`.\n    step: A value defining if we should apply a full projection (`step == 1`) or\n      a partial projection (`step < 1`).\n\n  Returns:\n    Projected list of tensors.\n  \"\"\"\n  projected_weights = list(weights)  # copy\n  for i in sorted_indices[::-1]:\n    if key_less_than_values[i]:\n      min_projection = projected_weights[i]\n      for j in key_less_than_values[i]:\n        min_projection = tf.minimum(min_projection, projected_weights[j])\n      if step == 1:\n        projected_weights[i] = min_projection\n      else:\n        projected_weights[i] = (\n            step * min_projection + (1 - step) * projected_weights[i])\n  return projected_weights\n\n\ndef _max_projection(weights, sorted_indices, key_greater_than_values, step):\n  \"\"\"Returns an approximate partial max projection with the given step_size.\n\n  Args:\n    weights: A list of tensors of shape `(units,)` to be approximatly projected\n      based on the monotonicity constraints.\n    sorted_indices: Topologically sorted list of indices based on the\n      monotonicity constraints.\n    
key_greater_than_values: A defaultdict from index to a list of indices,\n      indicating that for index `j` in `key_greater_than_values[i]` we must have\n      `weight[i] >= weight[j]`.\n    step: A value defining if we should apply a full projection (`step == 1`) or\n      a partial projection (`step < 1`).\n\n  Returns:\n    Projected list of tensors.\n  \"\"\"\n  projected_weights = list(weights)  # copy\n  for i in sorted_indices:\n    if key_greater_than_values[i]:\n      max_projection = projected_weights[i]\n      for j in key_greater_than_values[i]:\n        max_projection = tf.maximum(max_projection, projected_weights[j])\n      if step == 1:\n        projected_weights[i] = max_projection\n      else:\n        projected_weights[i] = (\n            step * max_projection + (1 - step) * projected_weights[i])\n  return projected_weights\n\n\ndef approximately_project_categorical_partial_monotonicities(\n    weights, monotonicities):\n  \"\"\"Returns an approximate L2 projection for categorical monotonicities.\n\n  Categorical monotonicities are monotonicity constraints applied to the real\n  values that are mapped from categorical inputs. Each monotonicity constraint\n  is specified by a pair of categorical input indices. The projection is also\n  used to constrain pairs of coefficients in linear models.\n\n  Args:\n    weights: Tensor of weights to be approximately projected based on the\n      monotonicity constraints.\n    monotonicities: List of pairs of indices `(i, j)`, indicating constraint\n      `weights[i] <= weights[j]`.\n  \"\"\"\n  key_less_than_values = collections.defaultdict(list)\n  key_greater_than_values = collections.defaultdict(list)\n  for i, j in monotonicities:\n    key_less_than_values[i].append(j)\n    key_greater_than_values[j].append(i)\n\n  sorted_indices = _topological_sort(key_less_than_values)\n\n  projected_weights = tf.unstack(weights)\n\n  # A 0.5 min projection followed by a full max projection.\n  projected_weights_min_max = _min_projection(projected_weights, sorted_indices,\n                                              key_less_than_values, 0.5)\n  projected_weights_min_max = _max_projection(projected_weights_min_max,\n                                              sorted_indices,\n                                              key_greater_than_values, 1)\n  projected_weights_min_max = tf.stack(projected_weights_min_max)\n\n  # A 0.5 max projection followed by a full min projection.\n  projected_weights_max_min = _max_projection(projected_weights, sorted_indices,\n                                              key_greater_than_values, 0.5)\n  projected_weights_max_min = _min_projection(projected_weights_max_min,\n                                              sorted_indices,\n                                              key_less_than_values, 1)\n  projected_weights_max_min = tf.stack(projected_weights_max_min)\n\n  # Take the average of the two results to avoid sliding to one direction.\n  projected_weights = (projected_weights_min_max +\n                       projected_weights_max_min) / 2\n  return projected_weights\n"
  },
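To make the projection above concrete, here is a minimal sketch of how `approximately_project_categorical_partial_monotonicities` behaves on a small violating example; the numbers mirror one of the parameterized cases in the test file that follows.

```python
import tensorflow as tf
from tensorflow_lattice.python import internal_utils

# Constraint weights[0] <= weights[1] is violated by the initial values.
weights = tf.constant([4.0, 3.0])
projected = internal_utils.approximately_project_categorical_partial_monotonicities(
    weights, monotonicities=[(0, 1)])
# Both entries are pulled to their average, 3.5, so the constraint holds.
```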
  {
    "path": "tensorflow_lattice/python/internal_utils_test.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for Tensorflow Lattice utility functions.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_lattice.python import internal_utils\n\n\nclass InternalUtilsTest(parameterized.TestCase, tf.test.TestCase):\n\n  def _ResetAllBackends(self):\n    tf.compat.v1.reset_default_graph()\n\n  @parameterized.parameters(\n      ([3., 4.], [(0, 1)], [3., 4.]), ([4., 3.], [(0, 1)], [3.5, 3.5]),\n      ([1., 0.], [(0, 1)], [0.5, 0.5]), ([-1., 0.], [(1, 0)], [-0.5, -0.5]),\n      ([4., 3., 2., 1., 0.], [(0, 1), (1, 2), (2, 3),\n                              (3, 4)], [2., 2., 2., 2., 2.]))\n  def testApproximatelyProjectCategoricalPartialMonotonicities(\n      self, weights, monotonicities, expected_projected_weights):\n    self._ResetAllBackends()\n    weights = tf.Variable(weights)\n    projected_weights = (\n        internal_utils.approximately_project_categorical_partial_monotonicities(\n            weights, monotonicities))\n    self.evaluate(tf.compat.v1.global_variables_initializer())\n    self.assertAllClose(\n        self.evaluate(projected_weights), np.array(expected_projected_weights))\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_lattice/python/kronecker_factored_lattice_layer.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Kronecker-Factored Lattice layer with monotonicity constraints.\n\nKeras implementation of tensorflow Kronecker-Factored Lattice layer. This layer\ntakes one or more d-dimensional input(s) and combines them using a\nKronecker-Factored Lattice function, satisfying monotonicity constraints if\nspecified.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport functools\nimport inspect\n\nimport tensorflow as tf\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\nfrom . import kronecker_factored_lattice_lib as kfl_lib\nfrom . import utils\n\nDIMS_NAME = \"dims\"\nKFL_SCALE_NAME = \"kronecker_factored_lattice_scale\"\nKFL_BIAS_NAME = \"kronecker_factored_lattice_bias\"\nKFL_KERNEL_NAME = \"kronecker_factored_lattice_kernel\"\nLATTICE_SIZES_NAME = \"lattice_sizes\"\nNUM_TERMS_NAME = \"num_terms\"\nUNITS_NAME = \"units\"\n\n\n# TODO: add support for different lattice_sizes for each input\n# dimension.\nclass KroneckerFactoredLattice(keras.layers.Layer):\n  # pyformat: disable\n  \"\"\"Kronecker-Factored Lattice layer.\n\n  A Kronecker-Factored Lattice is a reparameterization of a Lattice using\n  kronecker-facotrization, which gives us linear time and space complexity.\n  While the underlying representation is different, the input-output behavior\n  remains the same.\n\n  A Kronecker-Factored Lattice consists of 'units' lattices. Each unit computes\n  the function described below on a distinct 'dims'-dimensional vector x taken\n  from the input tensor. Each unit has its own set of parameters. The function\n  each unit computes is given by:\n\n  f(x) = b + (1/num_terms) * sum_{t=1}^{num_terms} scale_t * prod_{d=1}^{dims} PLF(x[d];w[d])\n\n  where bias and each scale_t are scalar parameters, w[d] is a\n  'lattice_size'-dimensional vector of parameters, and  PLF(;w) denotes the\n  one-dimensional piecewise-linear function with domain [0, lattice_sizes-1]\n  whose graph consists of lattice_sizes-1 linear segments interpolating the\n  points (i, w[i]), for i=0,1,...,lattice_size-1.\n\n  There is currently one type of constraint on the shape of the learned\n  function.\n\n  * **Monotonicity:** constrains the function to be increasing in the\n    corresponding dimension. 
To achieve decreasing monotonicity, either pass the\n    inputs through a `tfl.layers.PWLCalibration` with `decreasing` monotonicity,\n    or manually reverse the inputs as `lattice_size - 1 - inputs`.\n\n  There are upper and lower bound constraints on the output.\n\n  Input shape:\n    - if `units == 1`: tensor of shape: `(batch_size, ..., dims)`\n      or list of `dims` tensors of same shape: `(batch_size, ..., 1)`\n    - if `units > 1`: tensor of shape: `(batch_size, ..., units, dims)` or list\n      of `dims` tensors of same shape: `(batch_size, ..., units, 1)`\n\n    A typical shape is: `(batch_size, len(monotonicities))`\n\n  Output shape:\n    Tensor of shape: `(batch_size, ..., units)`\n\n  Attributes:\n    - All `__init__` arguments.\n    scale: A tensor of shape `(units, num_terms)`. Contains the `scale_t`\n      parameter for each unit for each term.\n    bias: A tensor of shape `(units)`. Contains the `b` parameter for each unit.\n    kernel: The `w` weights parameter of the Kronecker-Factored Lattice of\n      shape: `(1, lattice_sizes, units * dims, num_terms)`. Note that the kernel\n      is unit-major in its second to last dimension.\n\n  Example:\n\n  ```python\n  kfl = tfl.layers.KroneckerFactoredLattice(\n      # Number of vertices along each dimension.\n      lattice_sizes=2,\n      # Number of output units.\n      units=2,\n      # Number of independently trained submodels per unit, the outputs\n      # of which are averaged to get the final output.\n      num_terms=4,\n      # You can specify monotonicity constraints.\n      monotonicities=['increasing', 'none', 'increasing', 'increasing',\n                      'increasing', 'increasing', 'increasing'])\n  ```\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self,\n               lattice_sizes,\n               units=1,\n               num_terms=2,\n               monotonicities=None,\n               output_min=None,\n               output_max=None,\n               clip_inputs=True,\n               kernel_initializer=\"kfl_random_monotonic_initializer\",\n               scale_initializer=\"scale_initializer\",\n               **kwargs):\n    # pyformat: disable\n    \"\"\"Initializes an instance of `KroneckerFactoredLattice`.\n\n    Args:\n      lattice_sizes: Number of vertices per dimension (minimum is 2).\n      units: Output dimension of the layer. See class comments for details.\n      num_terms: Number of independently trained submodels per unit, the outputs\n        of which are averaged to get the final output.\n      monotonicities: None or list or tuple of same length as input dimension of\n        {'none', 'increasing', 0, 1} which specifies if the model output should\n        be monotonic in the corresponding feature, using 'increasing' or 1 to\n        indicate increasing monotonicity and 'none' or 0 to indicate no\n        monotonicity constraints.\n      output_min: None or lower bound of the output.\n      output_max: None or upper bound of the output.\n      clip_inputs: If inputs should be clipped to the input range of the\n        Kronecker-Factored Lattice.\n      kernel_initializer: None or one of:\n        - `'kfl_random_monotonic_initializer'`: initializes parameters as uniform\n          random functions that are monotonic in monotonic dimensions.\n        - Any Keras initializer object.\n      scale_initializer: None or one of:\n        - `'scale_initializer'`: Initializes scale depending on output_min and\n          output_max. 
If both output_min and output_max are set, scale is\n          initialized to half their difference, alternating signs for each term.\n          If only output_min is set, scale is initialized to 1 for each term. If\n          only output_max is set, scale is initialized to -1 for each term.\n          Otherwise scale is initialized to alternate between 1 and -1 for each\n          term.\n      **kwargs: Other args passed to `keras.layers.Layer` initializer.\n\n    Raises:\n      ValueError: If layer hyperparameters are invalid.\n    \"\"\"\n    # pyformat: enable\n    kfl_lib.verify_hyperparameters(\n        lattice_sizes=lattice_sizes,\n        units=units,\n        num_terms=num_terms,\n        output_min=output_min,\n        output_max=output_max)\n    super(KroneckerFactoredLattice, self).__init__(**kwargs)\n\n    self.lattice_sizes = lattice_sizes\n    self.units = units\n    self.num_terms = num_terms\n    self.monotonicities = monotonicities\n    self.output_min = output_min\n    self.output_max = output_max\n    self.clip_inputs = clip_inputs\n\n    self.kernel_initializer = create_kernel_initializer(\n        kernel_initializer_id=kernel_initializer,\n        monotonicities=self.monotonicities,\n        output_min=self.output_min,\n        output_max=self.output_max)\n\n    self.scale_initializer = create_scale_initializer(\n        scale_initializer_id=scale_initializer,\n        output_min=self.output_min,\n        output_max=self.output_max)\n\n  def build(self, input_shape):\n    \"\"\"Standard Keras build() method.\"\"\"\n    kfl_lib.verify_hyperparameters(\n        units=self.units,\n        input_shape=input_shape,\n        monotonicities=self.monotonicities)\n    # input_shape: (batch, ..., units, dims)\n    if isinstance(input_shape, list):\n      dims = len(input_shape)\n    else:\n      dims = input_shape.as_list()[-1]\n\n    if self.output_min is not None or self.output_max is not None:\n      scale_constraints = ScaleConstraints(\n          output_min=self.output_min, output_max=self.output_max)\n    else:\n      scale_constraints = None\n    self.scale = self.add_weight(\n        KFL_SCALE_NAME,\n        shape=[self.units, self.num_terms],\n        initializer=self.scale_initializer,\n        constraint=scale_constraints,\n        dtype=self.dtype)\n    self.bias = self.add_weight(\n        KFL_BIAS_NAME,\n        shape=[self.units],\n        initializer=BiasInitializer(self.output_min, self.output_max),\n        trainable=(self.output_min is None and self.output_max is None),\n        dtype=self.dtype)\n\n    if (self.monotonicities or self.output_min is not None or\n        self.output_max is not None):\n      constraints = KroneckerFactoredLatticeConstraints(\n          units=self.units,\n          scale=self.scale,\n          monotonicities=self.monotonicities,\n          output_min=self.output_min,\n          output_max=self.output_max)\n    else:\n      constraints = None\n\n    # Note that the first dimension of shape is 1 to work with\n    # tf.nn.depthwise_conv2d. 
We also provide scale to the __call__ method\n    # of the initializer using partial functions if it accepts scale.\n    parameters = inspect.signature(self.kernel_initializer).parameters.keys()\n    if \"scale\" in parameters:\n      # initial_value needs the lambda because it is a class property and the\n      # second and third arguments to tf.cond should be functions,\n      # but read_value is already a function, so the lambda is not needed.\n      kernel_initializer = functools.partial(\n          self.kernel_initializer,\n          scale=tf.cond(\n              tf.compat.v1.is_variable_initialized(self.scale),\n              self.scale.read_value,\n              lambda: self.scale.initial_value))\n    else:\n      kernel_initializer = self.kernel_initializer\n    self.kernel = self.add_weight(\n        KFL_KERNEL_NAME,\n        shape=[1, self.lattice_sizes, self.units * dims, self.num_terms],\n        initializer=kernel_initializer,\n        constraint=constraints,\n        dtype=self.dtype)\n\n    self._final_kernel_constraints = KroneckerFactoredLatticeConstraints(\n        units=self.units,\n        scale=self.scale,\n        monotonicities=self.monotonicities,\n        output_min=self.output_min,\n        output_max=self.output_max)\n\n    self._final_scale_constraints = ScaleConstraints(\n        output_min=self.output_min, output_max=self.output_max)\n\n    # These tensors are meant for book keeping. Note that this slightly\n    # increases the size of the graph.\n    self.lattice_sizes_tensor = tf.constant(\n        self.lattice_sizes, dtype=tf.int32, name=LATTICE_SIZES_NAME)\n    self.units_tensor = tf.constant(\n        self.units, dtype=tf.int32, name=UNITS_NAME)\n    self.dims_tensor = tf.constant(dims, dtype=tf.int32, name=DIMS_NAME)\n    self.num_terms_tensor = tf.constant(\n        self.num_terms, dtype=tf.int32, name=NUM_TERMS_NAME)\n\n    super(KroneckerFactoredLattice, self).build(input_shape)\n\n  def call(self, inputs):\n    \"\"\"Standard Keras call() method.\"\"\"\n    return kfl_lib.evaluate_with_hypercube_interpolation(\n        inputs=inputs,\n        scale=self.scale,\n        bias=self.bias,\n        kernel=self.kernel,\n        units=self.units,\n        num_terms=self.num_terms,\n        lattice_sizes=self.lattice_sizes,\n        clip_inputs=self.clip_inputs)\n\n  def compute_output_shape(self, input_shape):\n    \"\"\"Standard Keras compute_output_shape() method.\"\"\"\n    if isinstance(input_shape, list):\n      input_shape = input_shape[0]\n    if self.units == 1:\n      return tuple(input_shape[:-1]) + (1,)\n    else:\n      # Second to last dimension must be equal to 'units'. 
Nothing to append.\n      return input_shape[:-1]\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    config = {\n        \"lattice_sizes\": self.lattice_sizes,\n        \"units\": self.units,\n        \"num_terms\": self.num_terms,\n        \"monotonicities\": self.monotonicities,\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n        \"clip_inputs\": self.clip_inputs,\n        \"kernel_initializer\":\n            keras.initializers.serialize(\n                self.kernel_initializer, use_legacy_format=True),\n        \"scale_initializer\":\n            keras.initializers.serialize(\n                self.scale_initializer, use_legacy_format=True),\n    }  # pyformat: disable\n    config.update(super(KroneckerFactoredLattice, self).get_config())\n    return config\n\n  # TODO: can we remove this now that we always project at every step?\n  def finalize_constraints(self):\n    \"\"\"Ensures layers weights strictly satisfy constraints.\n\n    Applies approximate projection to strictly satisfy specified constraints.\n\n    Returns:\n      In eager mode directly updates kernel and scale and returns the variables\n      which store them. In graph mode returns a `group` op containing the\n      `assign_add` ops which have to be executed to update the kernel and scale.\n    \"\"\"\n    finalize_kernel = self.kernel.assign_add(\n        self._final_kernel_constraints(self.kernel) - self.kernel)\n    finalize_scale = self.scale.assign_add(\n        self._final_scale_constraints(self.scale) - self.scale)\n    return tf.group([finalize_kernel, finalize_scale])\n\n  def assert_constraints(self, eps=1e-6):\n    \"\"\"Asserts that weights satisfy all constraints.\n\n    In graph mode builds and returns list of assertion ops.\n    In eager mode directly executes assertions.\n\n    Args:\n      eps: allowed constraints violation.\n\n    Returns:\n      List of assertion ops in graph mode or immediately asserts in eager mode.\n    \"\"\"\n    return kfl_lib.assert_constraints(\n        weights=self.kernel,\n        units=self.units,\n        scale=self.scale,\n        monotonicities=utils.canonicalize_monotonicities(\n            self.monotonicities, allow_decreasing=False),\n        output_min=self.output_min,\n        output_max=self.output_max,\n        eps=eps)\n\n\ndef create_kernel_initializer(kernel_initializer_id,\n                              monotonicities,\n                              output_min,\n                              output_max,\n                              init_min=None,\n                              init_max=None):\n  \"\"\"Returns a kernel Keras initializer object from its id.\n\n  This function is used to convert the 'kernel_initializer' parameter in the\n  constructor of tfl.layers.KroneckerFactoredLattice into the corresponding\n  initializer object.\n\n  Args:\n    kernel_initializer_id: See the documentation of the 'kernel_initializer'\n      parameter in the constructor of `tfl.layers.KroneckerFactoredLattice`.\n    monotonicities: See the documentation of the same parameter in the\n      constructor of `tfl.layers.KroneckerFactoredLattice`.\n    output_min: See the documentation of the same parameter in the constructor\n      of `tfl.layers.KroneckerFactoredLattice`.\n    output_max: See the documentation of the same parameter in the constructor\n      of `tfl.layers.KroneckerFactoredLattice`.\n    init_min: None or lower bound of kernel initialization. If set, init_max\n      must also be set. 
Ignored if kernel_initializer_id is a Keras object.\n    init_max: None or upper bound of kernel initialization. If set, init_min\n      must also be set. Ignored if kernel_initializer_id is a Keras object.\n\n  Returns:\n    The Keras initializer object for the `tfl.layers.KroneckerFactoredLattice`\n    kernel variable.\n\n  Raises:\n    ValueError: If only one of init_{min/max} is set.\n  \"\"\"\n  if init_min is None and init_max is None:\n    init_min, init_max = kfl_lib.default_init_params(output_min, output_max)\n  elif init_min is not None and init_max is not None:\n    # We have nothing to set here.\n    pass\n  else:\n    raise ValueError(\"Both or neither of init_{min/max} must be set\")\n\n  # Construct initializer.\n  if kernel_initializer_id in [\n      \"kfl_random_monotonic_initializer\", \"KFLRandomMonotonicInitializer\"\n  ]:\n    return KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, init_min=init_min, init_max=init_max)\n  else:\n    # This is needed for Keras deserialization logic to be aware of our custom\n    # objects.\n    with keras.utils.custom_object_scope({\n        \"KFLRandomMonotonicInitializer\": KFLRandomMonotonicInitializer,\n    }):\n      return keras.initializers.get(kernel_initializer_id)\n\n\ndef create_scale_initializer(scale_initializer_id, output_min, output_max):\n  \"\"\"Returns a scale Keras initializer object from its id.\n\n  This function is used to convert the 'scale_initializer' parameter in the\n  constructor of tfl.layers.KroneckerFactoredLattice into the corresponding\n  initializer object.\n\n  Args:\n    scale_initializer_id: See the documentation of the 'scale_initializer'\n      parameter in the constructor of `tfl.layers.KroneckerFactoredLattice`.\n    output_min: See the documentation of the same parameter in the constructor\n      of `tfl.layers.KroneckerFactoredLattice`.\n    output_max: See the documentation of the same parameter in the constructor\n      of `tfl.layers.KroneckerFactoredLattice`.\n\n  Returns:\n    The Keras initializer object for the `tfl.layers.KroneckerFactoredLattice`\n    scale variable.\n  \"\"\"\n  # Construct initializer.\n  if scale_initializer_id in [\"scale_initializer\", \"ScaleInitializer\"]:\n    return ScaleInitializer(output_min=output_min, output_max=output_max)\n  else:\n    # This is needed for Keras deserialization logic to be aware of our custom\n    # objects.\n    with keras.utils.custom_object_scope({\n        \"ScaleInitializer\": ScaleInitializer,\n    }):\n      return keras.initializers.get(scale_initializer_id)\n\n\nclass KFLRandomMonotonicInitializer(keras.initializers.Initializer):\n  # pyformat: disable\n  \"\"\"Initializes a `tfl.layers.KroneckerFactoredLattice` as random monotonic.\"\"\"\n  # pyformat: enable\n\n  def __init__(self, monotonicities, init_min=0.5, init_max=1.5, seed=None):\n    \"\"\"Initializes an instance of `KFLRandomMonotonicInitializer`.\n\n    Args:\n      monotonicities: Monotonic dimensions for initialization. Does not need to\n        match `monotonicities` of `tfl.layers.KroneckerFactoredLattice`.\n      init_min: The lower bound on the range of initialized weights.\n      init_max: The upper bound on the range of initialized weights.\n      seed: A Python integer. 
Used to create a random seed for the distribution.\n    \"\"\"\n    self.monotonicities = monotonicities\n    self.init_min = init_min\n    self.init_max = init_max\n    self.seed = seed\n\n  def __call__(self, shape, scale, dtype=None, **kwargs):\n    \"\"\"Returns weights of `tfl.layers.KroneckerFactoredLattice` layer.\n\n    Args:\n      shape: Must be: `(1, lattice_sizes, units * dims, num_terms)`.\n      scale: Scale variable of shape: `(units, num_terms)`.\n      dtype: Standard Keras initializer param.\n      **kwargs: Other args passed to `keras.initializers.Initializer` __call__\n        method.\n    \"\"\"\n    return kfl_lib.kfl_random_monotonic_initializer(\n        shape=shape,\n        scale=scale,\n        monotonicities=utils.canonicalize_monotonicities(\n            self.monotonicities, allow_decreasing=False),\n        init_min=self.init_min,\n        init_max=self.init_max,\n        dtype=dtype,\n        seed=self.seed)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    config = {\n        \"monotonicities\": self.monotonicities,\n        \"init_min\": self.init_min,\n        \"init_max\": self.init_max,\n        \"seed\": self.seed,\n    }  # pyformat: disable\n    return config\n\n\nclass ScaleInitializer(keras.initializers.Initializer):\n  # pyformat: disable\n  \"\"\"Initializes scale depending on output_min and output_max.\n\n  If both output_min and output_max are set, scale is initialized to half their\n  difference, alternating signs for each term. If only output_min is set, scale\n  is initialized to 1 for each term. If only output_max is set, scale is\n  initialized to -1 for each term. Otherwise scale is initialized to alternate\n  between 1 and -1 for each term.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self, output_min, output_max):\n    \"\"\"Initializes an instance of `ScaleInitializer`.\n\n    Args:\n      output_min: None or minimum layer output.\n      output_max: None or maximum layer output.\n    \"\"\"\n    self.output_min = output_min\n    self.output_max = output_max\n\n  def __call__(self, shape, dtype=None, **kwargs):\n    \"\"\"Returns weights of `tfl.layers.KroneckerFactoredLattice` scale.\n\n    Args:\n      shape: Must be: `(units, num_terms)`.\n      dtype: Standard Keras initializer param.\n      **kwargs: Other args passed to `keras.initializers.Initializer` __call__\n        method.\n    \"\"\"\n    units, num_terms = shape\n    return kfl_lib.scale_initializer(\n        units=units,\n        num_terms=num_terms,\n        output_min=self.output_min,\n        output_max=self.output_max)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    config = {\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n    }  # pyformat: disable\n    return config\n\n\nclass BiasInitializer(keras.initializers.Initializer):\n  # pyformat: disable\n  \"\"\"Initializes bias depending on output_min and output_max.\n\n  If both output_min and output_max are set, bias is initialized to their\n  average. If only output_min is set, bias is initialized to output_min. If only\n  output_max is set, bias is initialized to output_max. 
Otherwise bias is\n  initialized to zeros.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self, output_min, output_max):\n    \"\"\"Initializes an instance of `BiasInitializer`.\n\n    Args:\n      output_min: None or minimum layer output.\n      output_max: None or maximum layer output.\n    \"\"\"\n    self.output_min = output_min\n    self.output_max = output_max\n\n  def __call__(self, shape, dtype=None, **kwargs):\n    \"\"\"Returns weights of `tfl.layers.KroneckerFactoredLattice` bias.\n\n    Args:\n      shape: Must be: `(units)`.\n      dtype: Standard Keras initializer param.\n      **kwargs: Other args passed to `keras.initializers.Initializer` __call__\n        method.\n    \"\"\"\n    return kfl_lib.bias_initializer(\n        units=shape[0],\n        output_min=self.output_min,\n        output_max=self.output_max,\n        dtype=dtype)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    config = {\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n    }  # pyformat: disable\n    return config\n\n\nclass KroneckerFactoredLatticeConstraints(keras.constraints.Constraint):\n  # pyformat: disable\n  \"\"\"Constraints for `tfl.layers.KroneckerFactoredLattice` layer.\n\n  Applies all constraints to the Kronecker-Factored Lattice weights. See\n  `tfl.layers.KroneckerFactoredLattice` for more details.\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self,\n               units,\n               scale,\n               monotonicities=None,\n               output_min=None,\n               output_max=None):\n    \"\"\"Initializes an instance of `KroneckerFactoredLatticeConstraints`.\n\n    Args:\n      units: Same meaning as corresponding parameter of\n        `KroneckerFactoredLattice`.\n      scale: Scale variable of shape: `(units, num_terms)`.\n      monotonicities: Same meaning as corresponding parameter of\n        `KroneckerFactoredLattice`.\n      output_min: Same meaning as corresponding parameter of\n        `KroneckerFactoredLattice`.\n      output_max: Same meaning as corresponding parameter of\n        `KroneckerFactoredLattice`.\n    \"\"\"\n    self.units = units\n    self.scale = scale\n    self.monotonicities = utils.canonicalize_monotonicities(\n        monotonicities, allow_decreasing=False)\n    self.num_constraint_dims = utils.count_non_zeros(self.monotonicities)\n    self.output_min = output_min\n    self.output_max = output_max\n\n  def __call__(self, w):\n    \"\"\"Applies constraints to `w`.\n\n    Args:\n      w: Kronecker-Factored Lattice weights tensor of shape: `(1, lattice_sizes,\n        units * dims, num_terms)`.\n\n    Returns:\n      Constrained and projected w.\n    \"\"\"\n    if self.num_constraint_dims:\n      w = kfl_lib.finalize_weight_constraints(\n          w,\n          units=self.units,\n          scale=self.scale,\n          monotonicities=self.monotonicities,\n          output_min=self.output_min,\n          output_max=self.output_max)\n    return w\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"units\": self.units,\n        \"scale\": self.scale,\n        \"monotonicities\": self.monotonicities,\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n    }  # pyformat: disable\n\n\nclass ScaleConstraints(keras.constraints.Constraint):\n  # pyformat: disable\n  \"\"\"Constraints for `tfl.layers.KroneckerFactoredLattice` 
scale.\n\n  Constrains the scale variable to be between\n  `[output_min-output_max, output_max-output_min]` such that the final output\n  of the layer is within the desired `[output_min, output_max]` range, assuming\n  bias is properly fixed to be `output_min`.\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self, output_min=None, output_max=None):\n    \"\"\"Initializes an instance of `ScaleConstraints`.\n\n    Args:\n      output_min: Same meaning as corresponding parameter of\n        `KroneckerFactoredLattice`.\n      output_max: Same meaning as corresponding parameter of\n        `KroneckerFactoredLattice`.\n    \"\"\"\n    self.output_min = output_min\n    self.output_max = output_max\n\n  def __call__(self, scale):\n    \"\"\"Applies constraints to `scale`.\n\n    Args:\n      scale: Kronecker-Factored Lattice scale tensor of shape: `(units,\n        num_terms)`.\n\n    Returns:\n      Constrained and clipped scale.\n    \"\"\"\n    if self.output_min is not None or self.output_max is not None:\n      scale = kfl_lib.finalize_scale_constraints(\n          scale, output_min=self.output_min, output_max=self.output_max)\n    return scale\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n    }  # pyformat: disable\n"
  },
  {
    "path": "tensorflow_lattice/python/kronecker_factored_lattice_lib.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Algorithm implementations required for Kronecker-Factored Lattice layer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom . import utils\nimport numpy as np\nimport tensorflow as tf\n\n\ndef custom_reduce_prod(t, axis):\n  \"\"\"tf.reduce_prod(t, axis) with faster custom gradient.\n\n  Shows comparable speed on CPU, up to 2x speed up on GPU, and 7x on TPU.\n\n  Args:\n    t: The tensor to reduce.\n    axis: The dimension to reduce.\n\n  Returns:\n    prod(t) and grad(prod(t))\n  \"\"\"\n\n  @tf.custom_gradient\n  def fn(t):\n    # Can safely use the built in forward op.\n    fwd = tf.reduce_prod(t, axis=axis)\n\n    def grad_fn(dy):\n      \"\"\"Computes the gradient function.\n\n      Args:\n        dy: The gradient flowing into the output of this function.\n\n      Returns:\n        The gradient flowing out through the input of this function.\n      \"\"\"\n      is_zero = tf.cast(tf.equal(t, 0), tf.float32)\n      num_zeros = tf.reduce_sum(is_zero, axis=axis)\n\n      # If the product contains no zero elements, then simply divide the\n      # product by each element to determine the partial gradients.\n      grad0 = tf.math.divide_no_nan(tf.expand_dims(fwd, axis=axis), t)\n\n      # If the product contained one zero element, then compute the gradient\n      # for that zero element. The gradients for other elements should be\n      # zero.\n      prod = tf.reduce_prod(t + is_zero, axis=axis)\n      grad1 = tf.cast(tf.equal(num_zeros, 1), tf.float32) * prod\n      grad1 = tf.expand_dims(grad1, axis=axis) * is_zero\n\n      return tf.expand_dims(dy, axis=axis) * (grad0 + grad1)\n\n    return fwd, grad_fn\n\n  return fn(t)\n\n\ndef evaluate_with_hypercube_interpolation(inputs, scale, bias, kernel, units,\n                                          num_terms, lattice_sizes,\n                                          clip_inputs):\n  \"\"\"Evaluates a Kronecker-Factored Lattice using hypercube interpolation.\n\n  Kronecker-Factored Lattice function is the product of the piece-wise linear\n  interpolation weights for each dimension of the input.\n\n  Args:\n    inputs: Tensor representing points to apply lattice interpolation to. If\n      units = 1, tensor should be of shape: `(batch_size, ..., dims)` or list of\n        `dims` tensors of same shape `(batch_size, ..., 1)`. If units > 1,\n        tensor\n      should be of shape: `(batch_size, ..., units, dims)` or list of `dims`\n        tensors of same shape `(batch_size, ..., units, 1)`. 
A typical shape is\n        `(batch_size, dims)`.\n    scale: Kronecker-Factored Lattice scale of shape `(units, num_terms)`.\n    bias: Kronecker-Factored Lattice bias of shape `(units)`.\n    kernel: Kronecker-Factored Lattice kernel of shape\n      `(1, lattice_sizes, units * dims, num_terms)`.\n    units: Output dimension of the Kronecker-Factored Lattice.\n    num_terms: Number of independently trained submodels per unit, the outputs\n      of which are averaged to get the final output.\n    lattice_sizes: Number of vertices per dimension.\n    clip_inputs: If inputs should be clipped to the input range of the\n      Kronecker-Factored Lattice.\n\n  Returns:\n    Tensor of shape: `(batch_size, ..., units)`.\n  \"\"\"\n  # Convert list of tensors to single tensor object.\n  if isinstance(inputs, list):\n    inputs = tf.concat(inputs, axis=-1)\n  if clip_inputs:\n    inputs = tf.clip_by_value(inputs, 0.0, lattice_sizes - 1.0)\n\n  inputs_shape = inputs.get_shape().as_list()\n  dims = inputs_shape[-1]\n  # Compute total dimension size before units excluding batch to squeeze into\n  # one axis.\n  idx = -1 if units == 1 else -2\n  rows = int(np.prod(inputs_shape[1:idx]))\n  inputs = tf.reshape(inputs, [-1, rows, units * dims])\n\n  # interpolation_weights.shape: (batch, rows, lattice_sizes, units * dims).\n  # interpolation_weights[m,n,i,j] should be the interpolation weight of the\n  # (m,n,j) input in the i'th vertex, i.e. 0 if dist(input[m,n,j], i) >= 1,\n  # otherwise 1 - dist(input[m,n,j], i), where `dist(...)` denotes the Euclidean\n  # distance between scalars.\n  if lattice_sizes == 2:\n    interpolation_weights = tf.stack([1 - inputs, inputs], axis=-2)\n  else:\n    vertices = tf.constant(\n        list(range(lattice_sizes)),\n        shape=(lattice_sizes, 1),\n        dtype=inputs.dtype)\n    interpolation_weights = vertices - tf.expand_dims(inputs, axis=-2)\n    interpolation_weights = 1 - tf.minimum(tf.abs(interpolation_weights), 1)\n\n  # dotprod.shape: (batch, rows, 1, units * dims * num_terms)\n  dotprod = tf.nn.depthwise_conv2d(\n      interpolation_weights, kernel, [1, 1, 1, 1], padding=\"VALID\")\n  dotprod = tf.reshape(dotprod, [-1, rows, units, dims, num_terms])\n\n  prod = custom_reduce_prod(dotprod, axis=-2)\n\n  results = scale * prod\n  # Average across terms for each unit.\n  results = tf.reduce_mean(results, axis=-1)\n  results = results + bias\n\n  # results.shape: (batch, rows, units)\n  results_shape = [-1] + inputs_shape[1:-1]\n  if units == 1:\n    results_shape.append(1)\n  results = tf.reshape(results, results_shape)\n  return results\n\n\ndef default_init_params(output_min, output_max):\n  \"\"\"Returns default initialization bounds depending on layer output bounds.\n\n  Args:\n    output_min: None or minimum layer output.\n    output_max: None or maximum layer output.\n  \"\"\"\n  if output_min is None and output_max is None:\n    return 0.5, 1.5\n  else:\n    return 0.0, 1.0\n\n\ndef kfl_random_monotonic_initializer(shape,\n                                     scale,\n                                     monotonicities,\n                                     init_min=0.5,\n                                     init_max=1.5,\n                                     dtype=tf.float32,\n                                     seed=None):\n  \"\"\"Returns a uniformly random sampled monotonic weight tensor.\n\n  - The uniform random monotonic function will initialize the lattice parameters\n    uniformly at random and make it such that the parameters are 
monotonically\n    increasing for each input.\n  - The random parameters will be sampled from `[init_min, init_max]`\n\n  Args:\n    shape: Shape of weights to initialize. Must be: `(1, lattice_sizes, units *\n      dims, num_terms)`.\n    scale: Scale variable of shape: `(units, num_terms)`.\n    monotonicities: None or list or tuple of length dims of elements of {0,1}\n      which represents monotonicity constraints per dimension. 1 stands for\n      increasing (non-decreasing in fact), 0 for no monotonicity constraints.\n    init_min: The lower bound on the range of initialized weights.\n    init_max: The upper bound on the range of initialized weights.\n    dtype: dtype\n    seed: A Python integer. Used to create a random seed for the distribution.\n\n  Returns:\n    Kronecker-Factored Lattice weights tensor of shape:\n    `(1, lattice_sizes, units * dims, num_terms)`.\n  \"\"\"\n  # Sample from the uniform distribution.\n  weights = tf.random.uniform(\n      shape, minval=init_min, maxval=init_max, dtype=dtype, seed=seed)\n  if utils.count_non_zeros(monotonicities) > 0:\n    # To sort, we must first reshape and unstack our weights.\n    dims = len(monotonicities)\n    _, lattice_sizes, units_times_dims, num_terms = shape\n    if units_times_dims % dims != 0:\n      raise ValueError(\n          \"len(monotonicities) is {}, which does not evenly divide shape[2].\"\n          \"len(monotonicities) should be equal to `dims`, and shape[2] \"\n          \"should be equal to units * dims.\".format(dims))\n    units = units_times_dims // dims\n    weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms])\n    # Make all dimensions monotonically increasing with respect to the sign of\n    # scale.\n    direction = tf.expand_dims(tf.sign(scale), axis=1)\n    # Now we can unstack each dimension.\n    weights = tf.unstack(direction * weights, axis=3)\n    monotonic_weights = [\n        tf.sort(weight, axis=1) if monotonicity else weight\n        for monotonicity, weight in zip(monotonicities, weights)\n    ]\n    # Restack, reshape, and return weights\n    weights = tf.stack(monotonic_weights, axis=3)\n    weights = tf.reshape(direction * weights, shape)\n  return weights\n\n\ndef scale_initializer(units, num_terms, output_min, output_max):\n  \"\"\"Initializes scale depending on output_min and output_max.\n\n  If both output_min and output_max are set, scale is initialized to half their\n  difference, alternating signs for each term. If only output_min is set, scale\n  is initialized to 1 for each term. If only output_max is set, scale is\n  initialized to -1 for each term. Otherwise scale is initialized to alternate\n  between 1 and -1 for each term.\n\n  Args:\n    units: Output dimension of the layer. 
Each unit's scale will be initialized\n      identically.\n    num_terms: Number of independently trained submodels per unit, the outputs\n      of which are averaged to get the final output.\n    output_min: None or minimum layer output.\n    output_max: None or maximum layer output.\n\n  Returns:\n    Kronecker-Factored Lattice scale of shape: `(units, num_terms)`.\n  \"\"\"\n  if output_min is not None and output_max is None:\n    return np.ones([units, num_terms])\n  if output_min is None and output_max is not None:\n    return -np.ones([units, num_terms])\n  # Both or neither bounds are set, so we alternate sign.\n  signs = (np.arange(num_terms) % -2) * 2 + 1\n  scale = np.tile(signs, [units, 1])\n  if output_min is not None and output_max is not None:\n    scale = scale * ((output_max - output_min) / 2.0)\n  return scale\n\n\ndef bias_initializer(units, output_min, output_max, dtype=tf.float32):\n  \"\"\"Initializes bias depending on output_min and output_max.\n\n  If both output_min and output_max are set, bias is initialized to their\n  average. If only output_min is set, bias is initialized to output_min. If only\n  output_max is set, bias is initialized to output_max. Otherwise bias is\n  initialized to zeros.\n\n  Args:\n    units: Output dimension of the layer. Each of units bias will be initialized\n      identically.\n    output_min: None or minimum layer output.\n    output_max: None or maximum layer output.\n    dtype: dtype\n\n  Returns:\n    Kronecker-Factored Lattice bias of shape: `(units)`.\n  \"\"\"\n  if output_min is not None and output_max is not None:\n    return tf.constant(\n        (output_min + output_max) / 2.0, shape=[units], dtype=dtype)\n  elif output_min is not None:\n    return tf.constant(output_min, shape=[units], dtype=dtype)\n  elif output_max is not None:\n    # In this case, weights will be nonnegative and scale will be nonpositive so\n    # we add output_max to interpolation output to achieve proper bound.\n    return tf.constant(output_max, shape=[units], dtype=dtype)\n  else:\n    return tf.zeros(shape=[units], dtype=dtype)\n\n\ndef _approximately_project_monotonicity(weights, units, scale, monotonicities):\n  \"\"\"Approximately projects to strictly meet monotonicity constraints.\n\n  For more details, see _approximately_project_monotonicity in lattice_lib.py.\n\n  Args:\n    weights: Tensor with weights of shape `(1, lattice_sizes, units * dims,\n      num_terms)`.\n    units: Number of units per input dimension.\n    scale: Scale variable of shape: `(units, num_terms)`.\n    monotonicities: List or tuple of length dims of elements of {0,1} which\n      represents monotonicity constraints per dimension. 1 stands for increasing\n      (non-decreasing in fact), 0 for no monotonicity constraints.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n  # Recall that w.shape is (1, lattice_sizes, units * dims, num_terms).\n  weights_shape = weights.get_shape().as_list()\n  _, lattice_sizes, units_times_dims, num_terms = weights_shape\n  assert units_times_dims % units == 0\n  dims = units_times_dims // units\n  weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms])\n\n  # Extract the sign of scale to determine the projection direction.\n  direction = tf.expand_dims(tf.sign(scale), axis=1)\n\n  # TODO: optimize for case where all dims are monotonic and we won't\n  # need to unstack.\n  # Unstack our weights such that we have the weight for each dimension. 
We\n  # multiply by direction such that we always project the weights to be\n  # increasing.\n  weights = tf.unstack(direction * weights, axis=3)\n  projected = []\n  for weight, monotonicity in zip(weights, monotonicities):\n    if monotonicity:\n      # First we go forward to find the maximum projection.\n      max_projection = tf.unstack(weight, axis=1)\n      for i in range(1, len(max_projection)):\n        max_projection[i] = tf.maximum(max_projection[i], max_projection[i - 1])\n      # Find the halfway projection to find the minimum projection.\n      half_projection = (weight + tf.stack(max_projection, axis=1)) / 2.0\n      # Now we go backwards to find the minimum projection.\n      min_projection = tf.unstack(half_projection, axis=1)\n      for i in range(len(min_projection) - 2, -1, -1):\n        min_projection[i] = tf.minimum(min_projection[i], min_projection[i + 1])\n      # Restack our weight from the minimum projection.\n      weight = tf.stack(min_projection, axis=1)\n    # Add our projected weight to our running list.\n    projected.append(weight)\n  # Restack our final projected weights. We multiply by direction such that if\n  # direction is negative we end up with decreasing weights.\n  weights = direction * tf.stack(projected, axis=3)\n\n  # Reshape projected weights into original shape and return them.\n  weights = tf.reshape(weights, weights_shape)\n  return weights\n\n\ndef _approximately_project_bounds(weights, units, output_min, output_max):\n  \"\"\"Approximately projects to strictly meet bound constraints.\n\n  For more details, see _approximately_project_bounds in lattice_lib.py.\n\n  Args:\n    weights: Tensor with weights of shape `(1, lattice_sizes, units * dims,\n      num_terms)`.\n    units: Number of units per input dimension.\n    output_min: None or minimum layer output.\n    output_max: None or maximum layer output.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n  if output_min is None and output_max is None:\n    return weights\n\n  # We project by the dims'th root projection factor of the weights, ultimately\n  # projecting each term into the range [-1,1], but only if both output_min and\n  # output_max are specified. 
Otherwise, we restrict the weights to be\n  # nonnegative and the interpolation will do a final shift to respect the\n  # one-sided bound.\n  if output_min is not None and output_max is not None:\n    # Recall that w.shape is (1, lattice_sizes, units * dims, num_terms).\n    weights_shape = weights.get_shape().as_list()\n    _, lattice_sizes, units_times_dims, num_terms = weights_shape\n    assert units_times_dims % units == 0\n    dims = units_times_dims // units\n    weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms])\n    max_keypoint_values = tf.reduce_max(tf.abs(weights), axis=1, keepdims=True)\n    max_output_value = tf.reduce_prod(\n        max_keypoint_values, axis=3, keepdims=True)\n    full_projection_factor = tf.maximum(max_output_value, 1.0)\n    individual_projection_factor = tf.pow(full_projection_factor, 1.0 / dims)\n    weights = weights / individual_projection_factor\n    # We must reshape to get our final projected weights.\n    weights = tf.reshape(weights, weights_shape)\n  else:\n    weights = tf.maximum(weights, 0)\n\n  return weights\n\n\n# Note: this function must not depend on the result of projecting scale.\n# Currently this function depends on the sign of scale, but the scale projection\n# will not flip the sign of scale (only make it 0 in the worst case), which will\n# not cause any issues.\ndef finalize_weight_constraints(weights, units, scale, monotonicities,\n                                output_min, output_max):\n  \"\"\"Approximately projects weights to strictly satisfy all constraints.\n\n  This projection guarantees that constraints are strictly met, but it is not\n  an exact projection w.r.t. the L2 norm. The computational cost is\n  `O(num_monotonic_dims * num_lattice_weights)`.\n\n  See helper functions `_approximately_project_*` for details of the individual\n  projection algorithms for each set of constraints.\n\n  Args:\n    weights: Kronecker-Factored Lattice weights tensor of shape: `(1,\n      lattice_sizes, units * dims, num_terms)`.\n    units: Number of units per input dimension.\n    scale: Scale variable of shape: `(units, num_terms)`.\n    monotonicities: List or tuple of length dims of elements of {0,1} which\n      represents monotonicity constraints per dimension. 
1 stands for increasing\n      (non-decreasing in fact), 0 for no monotonicity constraints.\n    output_min: None or minimum layer output.\n    output_max: None or maximum layer output.\n\n  Returns:\n    Projected weights tensor of same shape as `weights`.\n  \"\"\"\n  if utils.count_non_zeros(monotonicities) > 0:\n    # TODO: in the case of only one monotonic dimension, we only have to\n    # constrain the non-monotonic dimensions to be positive.\n    # There must be monotonicity constraints, so we need all nonnegative\n    # weights.\n    weights = tf.maximum(weights, 0)\n    weights = _approximately_project_monotonicity(\n        weights=weights,\n        units=units,\n        scale=scale,\n        monotonicities=monotonicities)\n\n  if output_min is not None or output_max is not None:\n    weights = _approximately_project_bounds(\n        weights=weights,\n        units=units,\n        output_min=output_min,\n        output_max=output_max)\n\n  return weights\n\n\n# Note: we cannot rely on the weights projection occurring always before or\n# always after the scale projection, so this function must not result in a\n# projection that would ultimately change the results of the weights projection.\n# Currently the weights projection depends on the sign of scale, so this\n# function does not change the sign (only makes scale 0 in the worst case),\n# which will not cause any issues.\ndef finalize_scale_constraints(scale, output_min, output_max):\n  \"\"\"Clips scale to strictly satisfy all constraints.\n\n  Args:\n    scale: Scale variable of shape: `(units, num_terms)`.\n    output_min: None or minimum layer output.\n    output_max: None or maximum layer output.\n\n  Returns:\n    Clipped scale tensor of same shape as `scale`.\n  \"\"\"\n  if output_min is not None and output_max is not None:\n    bound = (output_max - output_min) / 2.0\n    scale = tf.clip_by_value(scale, clip_value_min=-bound, clip_value_max=bound)\n  elif output_min is not None:\n    # In this case, we need scale to be nonnegative to properly shift by bias\n    # and satisfy the one-sided min bound.\n    scale = tf.maximum(scale, 0)\n  elif output_max is not None:\n    # In this case, we need scale to be nonpositive to properly mirror and shift\n    # by bias and satisfy the one-sided max bound.\n    scale = tf.minimum(scale, 0)\n  return scale\n\n\ndef verify_hyperparameters(lattice_sizes=None,\n                           units=None,\n                           num_terms=None,\n                           input_shape=None,\n                           monotonicities=None,\n                           output_min=None,\n                           output_max=None):\n  \"\"\"Verifies that all given hyperparameters are consistent.\n\n  This function does not inspect weights themselves, only their shape. Use\n  `assert_constraints()` to assert actual weights against constraints.\n\n  See `tfl.layers.KroneckerFactoredLattice` class level comment for detailed\n  description of arguments.\n\n  Args:\n    lattice_sizes: Lattice size to check against.\n    units: Units hyperparameter of `KroneckerFactoredLattice` layer.\n    num_terms: Number of independently trained submodels hyperparameter of\n      `KroneckerFactoredLattice` layer.\n    input_shape: Shape of layer input. Useful only if `units` and/or\n      `monotonicities` is set.\n    monotonicities: Monotonicities hyperparameter of `KroneckerFactoredLattice`\n      layer. 
Useful only if `input_shape` is set.\n    output_min: Minimum output of `KroneckerFactoredLattice` layer.\n    output_max: Maximum output of `KroneckerFactoredLattice` layer.\n\n  Raises:\n    ValueError: If lattice_sizes < 2.\n    ValueError: If units < 1.\n    ValueError: If num_terms < 1.\n    ValueError: If len(monotonicities) does not match number of inputs.\n  \"\"\"\n  if lattice_sizes and lattice_sizes < 2:\n    raise ValueError(\"Lattice size must be at least 2. Given: %s\" %\n                     lattice_sizes)\n\n  if units and units < 1:\n    raise ValueError(\"Units must be at least 1. Given: %s\" % units)\n\n  if num_terms and num_terms < 1:\n    raise ValueError(\"Number of terms must be at least 1. Given: %s\" %\n                     num_terms)\n\n  # input_shape: (batch, ..., units, dims)\n  if input_shape:\n    # It also raises errors if monotonicities is specified incorrectly.\n    monotonicities = utils.canonicalize_monotonicities(\n        monotonicities, allow_decreasing=False)\n    # Extract shape to check units and dims to check monotonicity\n    if isinstance(input_shape, list):\n      dims = len(input_shape)\n      # Check monotonicity.\n      if monotonicities and len(monotonicities) != dims:\n        raise ValueError(\"If input is provided as list of tensors, their number\"\n                         \" must match monotonicities. 'input_list': %s, \"\n                         \"'monotonicities': %s\" % (input_shape, monotonicities))\n      shape = input_shape[0]\n    else:\n      dims = input_shape.as_list()[-1]\n      # Check monotonicity.\n      if monotonicities and len(monotonicities) != dims:\n        raise ValueError(\"Last dimension of input shape must have same number \"\n                         \"of elements as 'monotonicities'. 'input shape': %s, \"\n                         \"'monotonicities': %s\" % (input_shape, monotonicities))\n      shape = input_shape\n    if units and units > 1 and (len(shape) < 3 or shape[-2] != units):\n      raise ValueError(\"If 'units' > 1 then input shape of \"\n                       \"KroneckerFactoredLattice layer must have rank at least \"\n                       \"3 where the second from the last dimension is equal to \"\n                       \"'units'. 'units': %s, 'input_shape': %s\" %\n                       (units, input_shape))\n\n  if output_min is not None and output_max is not None:\n    if output_min >= output_max:\n      raise ValueError(\"'output_min' must be strictly less than 'output_max'. \"\n                       \"'output_min': %f, 'output_max': %f\" %\n                       (output_min, output_max))\n\n\ndef _assert_monotonicity_constraints(weights, units, scale, monotonicities,\n                                     eps):\n  \"\"\"Asserts that weights satisfy monotonicity constraints.\n\n  Args:\n    weights: `KroneckerFactoredLattice` weights tensor of shape: `(1,\n      lattice_sizes, units * dims, num_terms)`.\n    units: Number of units per input dimension.\n    scale: Scale variable of shape: `(units, num_terms)`.\n    monotonicities: Monotonicity constraints.\n    eps: Allowed constraints violation.\n\n  Returns:\n    List of monotonicity assertion ops in graph mode or directly executes\n    assertions in eager mode and returns a list of NoneType elements.\n  \"\"\"\n  monotonicity_asserts = []\n\n  # Recall that w.shape is (1, lattice_sizes, units * dims, num_terms).\n  weights_shape = weights.get_shape().as_list()\n  _, lattice_sizes, units_times_dims, num_terms = weights_shape\n  assert units_times_dims % units == 0\n  dims = units_times_dims // units\n  weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms])\n\n  # Extract the sign of scale to determine the assertion direction.\n  direction = tf.expand_dims(tf.sign(scale), axis=1)\n\n  # Unstack our weights given our extracted sign.\n  weights = tf.unstack(direction * weights, axis=3)\n  for i, (weight, monotonicity) in enumerate(zip(weights, monotonicities)):\n    if monotonicity:\n      keypoints = tf.unstack(weight, axis=1)\n      for j in range(1, len(keypoints)):\n        diff = tf.reduce_min(keypoints[j] - keypoints[j - 1])\n        monotonicity_asserts.append(\n            tf.Assert(\n                diff >= -eps,\n                data=[\n                    \"Monotonicity violation\", \"Feature index:\", i,\n                    \"Min monotonicity diff:\", diff, \"Upper layer number:\", j,\n                    \"Epsilon:\", eps, \"Keypoints:\", keypoints[j],\n                    keypoints[j - 1]\n                ]))\n\n  return monotonicity_asserts\n\n\ndef _assert_bound_constraints(weights, units, scale, output_min, output_max,\n                              eps):\n  \"\"\"Asserts that weights and scale satisfy bound constraints.\n\n  Args:\n    weights: `KroneckerFactoredLattice` weights tensor of shape: `(1,\n      lattice_sizes, units * dims, num_terms)`.\n    units: Number of units per input dimension.\n    scale: Scale variable of shape: `(units, num_terms)`.\n    output_min: None or minimum layer output.\n    output_max: None or maximum layer output.\n    eps: Allowed constraints violation.\n\n  Returns:\n    List of bound assertion ops in graph mode or directly executes\n    assertions in eager mode and returns a list of NoneType elements.\n  \"\"\"\n  bound_asserts = []\n\n  # Recall that w.shape is (1, lattice_sizes, units * dims, num_terms).\n  weights_shape = weights.get_shape().as_list()\n  _, lattice_sizes, units_times_dims, num_terms = weights_shape\n  assert units_times_dims % units == 0\n  dims = units_times_dims // units\n  weights = tf.reshape(weights, [-1, lattice_sizes, units, dims, num_terms])\n\n  # If both bounds are specified, the maximum output must also be\n  # between -1 and 1.\n  if output_min is not None and output_max is not None:\n    for term, term_weights in enumerate(tf.unstack(weights, axis=4)):\n      max_keypoint_values = tf.reduce_max(\n          tf.abs(term_weights), axis=1, keepdims=True)\n      max_output_values = 
tf.reduce_prod(\n          max_keypoint_values, axis=3, keepdims=True)\n      for unit, unit_max_output_value in enumerate(\n          tf.unstack(max_output_values, axis=2)):\n        diff = tf.squeeze(1 - unit_max_output_value)\n        bound_asserts.append(\n            tf.Assert(\n                diff >= -eps,\n                data=[\n                    \"Bound violation (max output greater than 1)\", \"Diff\", diff,\n                    \"Epsilon\", eps, \"Maximum output value\",\n                    unit_max_output_value, \"Term index\", term, \"Unit\", unit,\n                    \"Weights\", weights\n                ]))\n  else:\n    # If only one bound is specified, we must have that all of our weights are\n    # nonnegative at this point. There can be no allowed epsilon error here\n    # because of the effect of a negative value.\n    total_negative_weights = tf.reduce_sum(tf.cast(weights < 0, tf.int32))\n    bound_asserts.append(\n        tf.Assert(\n            total_negative_weights <= 0,\n            data=[\n                \"Bound violation (negative weights)\",\n                \"Number of negative weights\", total_negative_weights, \"Weights\",\n                weights\n            ]))\n\n  # If both bounds are specified, scale must be between\n  # -(output_max-output_min)/2 and (output_max-output_min)/2. If only output_min\n  # is specified, then scale must be nonnegative. If only output_max is\n  # specified, then scale must be nonpositive.\n  if output_min is not None and output_max is not None:\n    bound = (output_max - output_min) / 2.0\n    below_bound_scales = tf.reduce_sum(tf.cast(scale < -bound, tf.int32))\n    above_bound_scale = tf.reduce_sum(tf.cast(scale > bound, tf.int32))\n    bound_asserts.append(\n        tf.Assert(\n            below_bound_scales + above_bound_scale <= 0,\n            data=[\n                \"Bound violation (scale out of bounds)\", \"Bound\", bound,\n                \"Scale\", scale\n            ]))\n  elif output_min is not None:\n    negative_scales = tf.reduce_sum(tf.cast(scale < 0, tf.int32))\n    bound_asserts.append(\n        tf.Assert(\n            negative_scales <= 0,\n            data=[\n                \"Bound violation (only output_min specified with negative \"\n                \"scale values)\", \"Scale\", scale\n            ]))\n  elif output_max is not None:\n    positive_scales = tf.reduce_sum(tf.cast(scale > 0, tf.int32))\n    bound_asserts.append(\n        tf.Assert(\n            positive_scales <= 0,\n            data=[\n                \"Bound violation (only output_max specified with positive \"\n                \"scale values)\", \"Scale\", scale\n            ]))\n\n  return bound_asserts\n\n\ndef assert_constraints(weights,\n                       units,\n                       scale,\n                       monotonicities,\n                       output_min,\n                       output_max,\n                       eps=1e-6):\n  \"\"\"Asserts that weights satisfy constraints.\n\n  Args:\n    weights: `KroneckerFactoredLattice` weights tensor of shape: `(1,\n      lattice_sizes, units * dims, num_terms)`.\n    units: Number of units per input dimension.\n    scale: Scale variable of shape: `(units, num_terms)`.\n    monotonicities: Monotonicity constraints.\n    output_min: None or minimum layer output.\n    output_max: None or maximum layer output.\n    eps: Allowed constraints violation.\n\n  Returns:\n    List of assertion ops in graph mode or directly executes assertions in eager\n    mode.\n  \"\"\"\n  
asserts = []\n\n  if monotonicities:\n    monotonicity_asserts = _assert_monotonicity_constraints(\n        weights=weights,\n        units=units,\n        scale=scale,\n        monotonicities=monotonicities,\n        eps=eps)\n    asserts.extend(monotonicity_asserts)\n\n  if output_min is not None or output_max is not None:\n    bound_asserts = _assert_bound_constraints(\n        weights=weights,\n        units=units,\n        scale=scale,\n        output_min=output_min,\n        output_max=output_max,\n        eps=eps)\n    asserts.extend(bound_asserts)\n\n  return asserts\n"
  },
  {
    "path": "tensorflow_lattice/python/kronecker_factored_lattice_test.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for KroneckerFactoredLattice Layer.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\n\nimport tempfile\nfrom absl import logging\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_lattice.python import kronecker_factored_lattice_layer as kfll\nfrom tensorflow_lattice.python import kronecker_factored_lattice_lib as kfl_lib\nfrom tensorflow_lattice.python import test_utils\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n\nclass KroneckerFactoredLatticeTest(parameterized.TestCase, tf.test.TestCase):\n\n  def setUp(self):\n    super(KroneckerFactoredLatticeTest, self).setUp()\n    self.disable_all = False\n    self.disable_ensembles = False\n    self.loss_eps = 0.001\n    self.small_eps = 1e-6\n    self.seed = 42\n    keras.utils.set_random_seed(42)\n\n  def _ResetAllBackends(self):\n    keras.backend.clear_session()\n    tf.compat.v1.reset_default_graph()\n\n  def _ScatterXUniformly(self, num_points, lattice_sizes, input_dims):\n    \"\"\"Deterministically generates num_point random points within lattice.\"\"\"\n    x = []\n    for _ in range(num_points):\n      point = [\n          np.random.random() * (lattice_sizes - 1.0) for _ in range(input_dims)\n      ]\n      x.append(np.asarray(point))\n    if input_dims == 1:\n      x.sort()\n    return x\n\n  def _ScatterXUniformlyExtendedRange(self, num_points, lattice_sizes,\n                                      input_dims):\n    \"\"\"Extends every dimension by 1.0 on both sides and generates points.\"\"\"\n    x = []\n    for _ in range(num_points):\n      point = [\n          np.random.random() * (lattice_sizes + 1.0) - 1.0\n          for _ in range(input_dims)\n      ]\n      x.append(np.asarray(point))\n    if input_dims == 1:\n      x.sort()\n    return x\n\n  def _SameValueForAllDims(self, num_points, lattice_sizes, input_dims):\n    \"\"\"Generates random point with same value for every dimension.\"\"\"\n    x = []\n    for _ in range(num_points):\n      rand = np.random.random() * (lattice_sizes - 1.0)\n      point = [rand] * input_dims\n      x.append(np.asarray(point))\n    if input_dims == 1:\n      x.sort()\n    return x\n\n  def _TwoDMeshGrid(self, num_points, lattice_sizes, input_dims):\n    \"\"\"Mesh grid for visualisation of 3-d surfaces via pyplot.\"\"\"\n    if input_dims != 2:\n      raise ValueError(\"2-d mesh grid is possible only for 2-d lattice. 
Lattice\"\n                       \" dimension given: %s\" % input_dims)\n    return test_utils.two_dim_mesh_grid(\n        num_points=num_points,\n        x_min=0.0,\n        y_min=0.0,\n        x_max=lattice_sizes - 1.0,\n        y_max=lattice_sizes - 1.0)\n\n  def _TwoDMeshGridExtendedRange(self, num_points, lattice_sizes, input_dims):\n    \"\"\"Mesh grid extended by 1.0 on every side.\"\"\"\n    if input_dims != 2:\n      raise ValueError(\"2-d mesh grid is possible only for 2-d lattice. Lattice\"\n                       \" dimension given: %s\" % input_dims)\n    return test_utils.two_dim_mesh_grid(\n        num_points=num_points,\n        x_min=-1.0,\n        y_min=-1.0,\n        x_max=lattice_sizes,\n        y_max=lattice_sizes)\n\n  def _Sin(self, x):\n    return math.sin(x[0])\n\n  def _SinPlusX(self, x):\n    return math.sin(x[0]) + x[0] / 3.0\n\n  def _SinPlusLargeX(self, x):\n    return math.sin(x[0]) + x[0]\n\n  def _SinPlusXNd(self, x):\n    return np.sum([math.sin(y) + y / 5.0 for y in x])\n\n  def _SinOfSum(self, x):\n    return math.sin(sum(x))\n\n  def _Max(self, x):\n    return np.amax(x)\n\n  def _ScaledSum(self, x):\n    result = 0.0\n    for y in x:\n      result += y / len(x)\n    return result\n\n  def _GetNonMonotonicInitializer(self, weights):\n    \"\"\"Tiles given weights along 'units' dimension.\"\"\"\n    dims = len(weights)\n\n    def Initializer(shape, dtype):\n      _, lattice_sizes, num_inputs, num_terms = shape\n      units = num_inputs // dims\n      # Create expanded weights, tile, reshape, return.\n      return tf.reshape(\n          tf.tile(\n              tf.constant(\n                  weights,\n                  shape=[1, lattice_sizes, 1, dims, num_terms],\n                  dtype=dtype),\n              multiples=[1, 1, units, 1, 1]), shape)\n\n    return Initializer\n\n  def _GetTrainingInputsAndLabels(self, config):\n    \"\"\"Generates training inputs and labels.\n\n    Args:\n      config: Dictionary with config for this unit test.\n\n    Returns:\n      Tuple `(training_inputs, training_labels)` where\n      `training_inputs` and `training_labels` are data for training.\n    \"\"\"\n    raw_training_inputs = config[\"x_generator\"](\n        num_points=config[\"num_training_records\"],\n        lattice_sizes=config[\"lattice_sizes\"],\n        input_dims=config[\"input_dims\"])\n\n    if isinstance(raw_training_inputs, tuple):\n      # This means that raw inputs are 2-d mesh grid. 
Convert them into list of\n      # 2-d points.\n      training_inputs = list(np.dstack(raw_training_inputs).reshape((-1, 2)))\n    else:\n      training_inputs = raw_training_inputs\n\n    training_labels = [config[\"y_function\"](x) for x in training_inputs]\n    return training_inputs, training_labels\n\n  def _SetDefaults(self, config):\n    config.setdefault(\"units\", 1)\n    config.setdefault(\"num_terms\", 2)\n    config.setdefault(\"monotonicities\", None)\n    config.setdefault(\"output_min\", None)\n    config.setdefault(\"output_max\", None)\n    config.setdefault(\"signal_name\", \"TEST\")\n    config.setdefault(\"target_monotonicity_diff\", 0.0)\n    config.setdefault(\"lattice_index\", 0)\n    config.setdefault(\"scale_initializer\", \"scale_initializer\")\n\n    return config\n\n  def _TestEnsemble(self, config):\n    \"\"\"Verifies that 'units > 1' lattice produces same output as 'units==1'.\"\"\"\n    # Note that the initialization of the lattice must be the same across the\n    # units dimension (otherwise the loss will be different).\n    # We fix the random seed to make sure we get similar initialization.\n    if self.disable_ensembles:\n      return\n    config = dict(config)\n    config[\"num_training_epoch\"] = 3\n    config[\"kernel_initializer\"] = \"constant\"\n    losses = []\n    for units, lattice_index in [(1, 0), (3, 0), (3, 2)]:\n      config[\"units\"] = units\n      config[\"lattice_index\"] = lattice_index\n      keras.utils.set_random_seed(42)\n      losses.append(self._TrainModel(config))\n    self.assertAlmostEqual(min(losses), max(losses), delta=self.loss_eps)\n\n  def _TrainModel(self, config):\n    logging.info(\"Testing config:\")\n    logging.info(config)\n    config = self._SetDefaults(config)\n    self._ResetAllBackends()\n\n    training_inputs, training_labels = (\n        self._GetTrainingInputsAndLabels(config))\n\n    units = config[\"units\"]\n    input_dims = config[\"input_dims\"]\n    lattice_sizes = config[\"lattice_sizes\"]\n    if units > 1:\n      # In order to test multi 'units' lattice, replicate inputs 'units' times\n      # and later use just one out of 'units' outputs in order to ensure that\n      # the multi 'units' lattice trains exactly like the single 'units' one.\n      training_inputs = [\n          np.tile(np.expand_dims(x, axis=0), reps=[units, 1])\n          for x in training_inputs\n      ]\n      input_shape = (units, input_dims)\n    else:\n      input_shape = (input_dims,)\n\n    keras_layer = kfll.KroneckerFactoredLattice(\n        lattice_sizes=lattice_sizes,\n        units=units,\n        num_terms=config[\"num_terms\"],\n        monotonicities=config[\"monotonicities\"],\n        output_min=config[\"output_min\"],\n        output_max=config[\"output_max\"],\n        kernel_initializer=config[\"kernel_initializer\"],\n        scale_initializer=config[\"scale_initializer\"],\n        input_shape=input_shape,\n        dtype=tf.float32)\n    model = keras.models.Sequential()\n    model.add(keras_layer)\n\n    # When we use multi-unit lattices, we only extract a single lattice for\n    # testing.\n    if units > 1:\n      lattice_index = config[\"lattice_index\"]\n      model.add(\n          keras.layers.Lambda(lambda x: x[:, lattice_index:lattice_index + 1]))\n\n    optimizer = config[\"optimizer\"](learning_rate=config[\"learning_rate\"])\n    model.compile(loss=keras.losses.mean_squared_error, optimizer=optimizer)\n\n    training_data = (training_inputs, training_labels)\n    loss = test_utils.run_training_loop(\n   
     config=config, training_data=training_data, keras_model=model\n    )\n\n    if tf.executing_eagerly():\n      tf.print(\"final weights: \", keras_layer.kernel)\n      tf.print(\"final scale: \", keras_layer.scale)\n      tf.print(\"final bias: \", keras_layer.bias)\n    assertion_ops = keras_layer.assert_constraints(\n        eps=-config[\"target_monotonicity_diff\"])\n    if not tf.executing_eagerly() and assertion_ops:\n      tf.compat.v1.keras.backend.get_session().run(assertion_ops)\n\n    return loss\n\n  def testMonotonicityOneD(self):\n    if self.disable_all:\n      return\n    monotonicities = [1]\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 20,\n        \"input_dims\": 1,\n        \"num_training_records\": 128,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinPlusX,\n        \"monotonicities\": monotonicities,\n        \"kernel_initializer\": kernel_initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.114794, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    monotonicities = [\"increasing\"]\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 20,\n        \"input_dims\": 1,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": lambda x: -self._SinPlusX(x),\n        \"monotonicities\": monotonicities,\n        \"kernel_initializer\": kernel_initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 3.011028, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    monotonicities = [1]\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 5,\n        \"input_dims\": 1,\n        \"num_terms\": 1,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinPlusLargeX,\n        \"monotonicities\": monotonicities,\n        \"kernel_initializer\": kernel_initializer,\n        # Target function is strictly increasing.\n        \"target_monotonicity_diff\": 0.01,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000832, delta=self.loss_eps)\n\n  def testMonotonicityTwoD(self):\n    if self.disable_all:\n      return\n    monotonicities = [1, 1]\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 21,\n        \"input_dims\": 2,\n        \"num_training_records\": 900,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._SinPlusXNd,\n        \"monotonicities\": monotonicities,\n        
\"kernel_initializer\": kernel_initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.407444, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    monotonicities = [\"none\", \"increasing\"]\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 21,\n        \"input_dims\": 2,\n        \"num_training_records\": 900,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._SinPlusXNd,\n        \"monotonicities\": monotonicities,\n        \"kernel_initializer\": kernel_initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.354508, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    monotonicities = [1, 0]\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 21,\n        \"input_dims\": 2,\n        \"num_training_records\": 900,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._SinPlusXNd,\n        \"monotonicities\": monotonicities,\n        \"kernel_initializer\": kernel_initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.365634, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    monotonicities = [1, 1]\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 2,\n        \"input_dims\": 2,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": lambda x: -self._ScaledSum(x),\n        \"monotonicities\": monotonicities,\n        \"kernel_initializer\": kernel_initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.054951, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testMonotonicity5d(self):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": 2,\n        \"input_dims\": 5,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._ScaledSum,\n        \"monotonicities\": [1, 1, 1, 1, 1],\n        \"kernel_initializer\": keras.initializers.Constant(value=0.5),\n        # Function is strictly increasing everywhere, so request monotonicity\n        # diff to be strictly positive.\n        \"target_monotonicity_diff\": 0.08,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000524, delta=self.loss_eps)\n\n    monotonicities = [1, 1, 1, 1, 1]\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 2,\n        \"input_dims\": 5,\n        
\"num_training_records\": 100,\n        \"num_training_epoch\": 40,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": lambda x: -self._ScaledSum(x),\n        \"monotonicities\": monotonicities,\n        \"kernel_initializer\": kernel_initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.016635, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    monotonicities = [1, \"increasing\", 1, 1]\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 3,\n        \"input_dims\": 4,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinOfSum,\n        \"monotonicities\": monotonicities,\n        \"kernel_initializer\": kernel_initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.398279, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([0, 1, 1],),\n      ([1, 0, 1],),\n      ([1, 1, 0],),\n  )\n  def testMonotonicityEquivalence(self, monotonicities):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": 3,\n        \"input_dims\": 3,\n        \"monotonicities\": monotonicities,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": self._SameValueForAllDims,\n        \"y_function\": self._SinOfSum,\n        \"kernel_initializer\": \"zeros\",\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.522080, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testMonotonicity10dAlmostMonotone(self):\n    if self.disable_all:\n      return\n    num_weights = 1024\n    weights = [1.0 * i / num_weights for i in range(num_weights)]\n    for _ in range(10):\n      i = int(np.random.random() * num_weights)\n      weights[i] = 0.0\n\n    config = {\n        \"lattice_sizes\": 2,\n        \"input_dims\": 10,\n        \"num_terms\": 128,\n        \"num_training_records\": 1000,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": test_utils.get_hypercube_interpolation_fn(weights),\n        \"monotonicities\": [1] * 10,\n        \"kernel_initializer\": \"zeros\",\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.027578, delta=self.loss_eps)\n\n    config[\"monotonicities\"] = [0, 1, 0, 1, 1, 0, 1, 1, 1, 0]\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.027578, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testMonotonicity10dSinOfSum(self):\n    if self.disable_all:\n      return\n    monotonicities = [1] * 10\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 2,\n        \"input_dims\": 10,\n        \"num_training_records\": 1000,\n        \"num_training_epoch\": 
100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinOfSum,\n        \"monotonicities\": [1] * 10,\n        \"kernel_initializer\": kernel_initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.158914, delta=self.loss_eps)\n\n    monotonicities = [0, 1, 0, 1, 1, 0, 1, 1, 1, 0]\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config[\"monotonicities\"] = monotonicities\n    config[\"kernel_initializer\"] = kernel_initializer\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.196240, delta=self.loss_eps)\n\n  @parameterized.parameters(\n      # Custom TFL initializer:\n      (\"kfl_random_monotonic_initializer\", 2.024714),\n      # Standard Keras initializer:\n      (keras.initializers.Constant(value=1.5), 2.140740),\n      # Standard Keras initializer specified as string constant:\n      (\"zeros\", 2.140740),\n  )\n  def testInitializerType(self, initializer, expected_loss):\n    if self.disable_all:\n      return\n    if initializer == \"kfl_random_monotonic_initializer\":\n      initializer = kfll.KFLRandomMonotonicInitializer(\n          monotonicities=None, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 3,\n        \"input_dims\": 2,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._Max,\n        \"kernel_initializer\": initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testAssertMonotonicity(self):\n    if self.disable_all:\n      return\n    # Specify non monotonic initializer and do 0 training iterations so no\n    # projections are being executed.\n    config = {\n        \"lattice_sizes\": 2,\n        \"input_dims\": 2,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._ScaledSum,\n        \"monotonicities\": [0, 0],\n        \"kernel_initializer\": self._GetNonMonotonicInitializer(\n            weights=[\n                [[4.0, 3.0], [4.0, 3.0]],\n                [[2.0, 1.0], [2.0, 1.0]]\n            ])\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 4.458333, delta=self.loss_eps)\n\n    for monotonicity in [[0, 1], [1, 0], [1, 1]]:\n      for units in [1, 3]:\n        config[\"monotonicities\"] = monotonicity\n        config[\"units\"] = units\n        with self.assertRaises(tf.errors.InvalidArgumentError):\n          self._TrainModel(config)\n\n  @parameterized.parameters(\n      (\n          -1,\n          1,\n          kfll.KFLRandomMonotonicInitializer(\n              monotonicities=None, init_min=-10, init_max=10\n          ),\n          \"scale_initializer\",\n      ),\n      (\n          None,\n          1,\n          kfll.KFLRandomMonotonicInitializer(\n              monotonicities=None, init_min=-10, init_max=10\n          ),\n          \"scale_initializer\",\n      ),\n      (\n          -1,\n          None,\n      
    kfll.KFLRandomMonotonicInitializer(\n              monotonicities=None, init_min=-10, init_max=10\n          ),\n          \"scale_initializer\",\n      ),\n      (\n          -1,\n          1,\n          \"kfl_random_monotonic_initializer\",\n          keras.initializers.Constant(value=-100),\n      ),\n      (\n          None,\n          1,\n          \"kfl_random_monotonic_initializer\",\n          keras.initializers.Constant(value=100),\n      ),\n      (\n          -1,\n          None,\n          \"kfl_random_monotonic_initializer\",\n          keras.initializers.Constant(value=-100),\n      ),\n  )\n  def testAssertBounds(self, output_min, output_max, kernel_initializer,\n                       scale_initializer):\n    if self.disable_all:\n      return\n    # Specify random initializer that ensures initial output can be out of\n    # bounds and do 0 training iterations so no projections are executed.\n    config = {\n        \"lattice_sizes\": 2,\n        \"input_dims\": 2,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._ScaledSum,\n        \"monotonicities\": [0, 0],\n        \"output_min\": output_min,\n        \"output_max\": output_max,\n        \"kernel_initializer\": kernel_initializer,\n        \"scale_initializer\": scale_initializer,\n    }\n    with self.assertRaises(tf.errors.InvalidArgumentError):\n      self._TrainModel(config)\n\n  @parameterized.parameters(\n      (2, 1, -3, -1, 4.82327),\n      (2, 2, 0, 1, 0.163572),\n      (1, 2, -5, 5, 0.011432),\n      (1, 10, -1, 1, 0.6307245),\n      (1, 3, None, None, 0.012286),\n      (1, 2, None, 5, 0.011590),\n      (3, 3, 0, None, 0.011679),\n      (4, 2, None, -2, 9.9507179),\n  )\n  def testOutputBounds(self, units, input_dims, output_min, output_max,\n                       expected_loss):\n    if self.disable_all:\n      return\n    monotonicities = [1] * input_dims\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 4,\n        \"units\": units,\n        \"input_dims\": input_dims,\n        \"num_training_records\": 900,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinPlusX,\n        \"monotonicities\": monotonicities,\n        \"output_min\": output_min,\n        \"output_max\": output_max,\n        \"kernel_initializer\": kernel_initializer,\n        # This is the epsilon error allowed when asserting constraints,\n        # including bounds. 
We include this to ensure that the bound constraint\n        # assertions do not fail due to numerical errors.\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      (2, 1, 3, 2, -1, 1),\n      (2, 2, 2, 1, None, 1),\n      (2, 1, 3, 4, -1, None),\n      (3, 3, 4, 2, -50, 2),\n      (4, 4, 2, 4, -1.5, 2.3),\n      (2, 2, 2, 2, None, None),\n  )\n  # Note: dims must be at least 1\n  def testConstraints(self, lattice_sizes, units, dims, num_terms, output_min,\n                      output_max):\n    if self.disable_all:\n      return\n    # Run our test for 100 iterations to minimize the chance we pass by luck.\n    for _ in range(100):\n      # Create 100 random inputs that are frozen in all but the increasing\n      # dimension, which increases uniformly from 0 to lattice_sizes-1.\n      batch_size = 100\n      random_vals = [\n          np.random.uniform(0, lattice_sizes - 1) for _ in range(dims - 1)\n      ]\n      increasing_dim = np.random.randint(0, dims)\n      step_size = (lattice_sizes - 1) / batch_size\n      values = [\n          np.roll([0.0 + (i * step_size)] + random_vals, increasing_dim)\n          for i in range(batch_size)\n      ]\n      if units > 1:\n        values = [[value] * units for value in values]\n        shape = [batch_size, units, dims]\n      else:\n        shape = [batch_size, dims]\n      inputs = tf.constant(values, dtype=tf.float32, shape=shape)\n\n      # Create our weights, constrain them, and evaluate our function on our\n      # constructed inputs.\n      init_min = -1.5 if output_min is None else output_min\n      init_max = 1.5 if output_max is None else output_max\n\n      # Offset the initialization bounds to increase likelihood of breaking\n      # constraints.\n      offset = 100\n      kernel = tf.random.uniform([1, lattice_sizes, units * dims, num_terms],\n                                 minval=init_min - offset,\n                                 maxval=init_max + offset)\n      scale = tf.random.uniform([units, num_terms],\n                                minval=init_min,\n                                maxval=init_max)\n      bias = kfl_lib.bias_initializer(\n          units, output_min, output_max, dtype=tf.float32)\n\n      scale_constraint = kfll.ScaleConstraints(output_min, output_max)\n      constrained_scale = scale_constraint(scale)\n\n      monotonicities = [np.random.randint(0, 2) for _ in range(dims)]\n      monotonicities[increasing_dim] = 1\n      kernel_constraint = kfll.KroneckerFactoredLatticeConstraints(\n          units, constrained_scale, monotonicities, output_min, output_max)\n      constrained_kernel = kernel_constraint(kernel)\n\n      outputs = kfl_lib.evaluate_with_hypercube_interpolation(\n          inputs=inputs,\n          scale=constrained_scale,\n          bias=bias,\n          kernel=constrained_kernel,\n          units=units,\n          num_terms=num_terms,\n          lattice_sizes=lattice_sizes,\n          clip_inputs=True)\n\n      # Check that outputs are inside our bounds.\n      min_check = float(\"-inf\") if output_min is None else output_min\n      self.assertEqual(tf.reduce_sum(tf.cast(outputs < min_check, tf.int32)), 0)\n      max_check = float(\"+inf\") if output_max is None else output_max\n      self.assertEqual(tf.reduce_sum(tf.cast(outputs > max_check, tf.int32)), 0)\n      # Check that we satisfy 
monotonicity constraints. Note that by\n      # construction the outputs should already be in sorted order.\n      sorted_outputs = tf.sort(outputs, axis=0)\n      # We use close equality instead of strict equality because of numerical\n      # errors that result in nearly identical arrays failing a strict check\n      # after sorting.\n      self.assertAllClose(outputs, sorted_outputs, rtol=1e-6, atol=1e-6)\n\n  def testInputOutOfBounds(self):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": 6,\n        \"input_dims\": 1,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformlyExtendedRange,\n        \"y_function\": self._Sin,\n        \"kernel_initializer\": keras.initializers.Zeros(),\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.028617, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=None, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 2,\n        \"input_dims\": 2,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGridExtendedRange,\n        \"y_function\": self._SinOfSum,\n        \"kernel_initializer\": kernel_initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.323999, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testHighDimensionsStressTest(self):\n    if self.disable_all:\n      return\n    monotonicities = [0] * 16\n    monotonicities[3], monotonicities[4], monotonicities[10] = (1, 1, 1)\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=monotonicities, seed=self.seed)\n    config = {\n        \"lattice_sizes\": 2,\n        \"input_dims\": 16,\n        \"num_terms\": 128,\n        \"units\": 2,\n        \"monotonicities\": monotonicities,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 3,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinOfSum,\n        \"kernel_initializer\": kernel_initializer,\n        \"target_monotonicity_diff\": -1e-5,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.30680, delta=self.loss_eps)\n\n  @parameterized.parameters(\n      (2, 5, 2, 57),\n      (2, 6, 4, 57),\n      (2, 9, 2, 57),\n      (3, 5, 4, 63),\n      (3, 9, 2, 63),\n  )\n  def testGraphSize(self, lattice_sizes, input_dims, num_terms,\n                    expected_graph_size):\n    # If this test failed then you modified core lattice interpolation logic in\n    # a way which increases number of ops in the graph. Or maybe Keras team\n    # changed something under the hood. 
Please ensure that this increase is\n    # unavoidable and try to minimize it.\n    if self.disable_all:\n      return\n    tf.compat.v1.disable_eager_execution()\n    tf.compat.v1.reset_default_graph()\n\n    layer = kfll.KroneckerFactoredLattice(\n        lattice_sizes=lattice_sizes, num_terms=num_terms)\n    input_tensor = tf.ones(shape=(1, input_dims))\n    layer(input_tensor)\n    graph_size = len(tf.compat.v1.get_default_graph().as_graph_def().node)\n\n    self.assertLessEqual(graph_size, expected_graph_size)\n\n  @parameterized.parameters(\n      (\"random_uniform\", keras.initializers.RandomUniform),\n      (\"kfl_random_monotonic_initializer\", kfll.KFLRandomMonotonicInitializer),\n  )\n  def testCreateKernelInitializer(self, kernel_initializer_id, expected_type):\n    self.assertEqual(\n        expected_type,\n        type(\n            kfll.create_kernel_initializer(\n                kernel_initializer_id,\n                monotonicities=None,\n                output_min=None,\n                output_max=None)))\n\n  # We test that the scale variable attribute of our KroneckerFactoredLattice\n  # is the same object as the scale contained in the constraint on the kernel,\n  # both before and after save/load. We test this because we must make sure that\n  # any updates to the scale variable (before/after save/load) are consistent\n  # across all uses of the object.\n  def testSavingLoadingScale(self):\n    # Create simple x --> x^2 dataset.\n    train_data = [[[float(x)], float(x)**2] for x in range(100)]\n    train_x, train_y = zip(*train_data)\n    train_x, train_y = np.array(train_x), np.array(train_y)\n    # Construct simple single lattice model. Must have monotonicities specified\n    # or constraint will be None.\n    keras_layer = kfll.KroneckerFactoredLattice(\n        lattice_sizes=2, monotonicities=[1])\n    model = keras.models.Sequential()\n    model.add(keras_layer)\n    # Compile and fit the model.\n    model.compile(\n        loss=\"mse\", optimizer=keras.optimizers.Adam(learning_rate=0.1))\n    model.fit(train_x, train_y)\n    # Extract scale from layer and constraint before save.\n    layer_scale = keras_layer.scale\n    constraint_scale = keras_layer.kernel.constraint.scale\n    self.assertIs(layer_scale, constraint_scale)\n    # Save and load the model.\n    with tempfile.NamedTemporaryFile(suffix=\".h5\") as f:\n      keras.models.save_model(model, f.name)\n      loaded_model = keras.models.load_model(\n          f.name,\n          custom_objects={\n              \"KroneckerFactoredLattice\":\n                  kfll.KroneckerFactoredLattice,\n              \"KroneckerFactoredLatticeConstraint\":\n                  kfll.KroneckerFactoredLatticeConstraints,\n              \"KFLRandomMonotonicInitializer\":\n                  kfll.KFLRandomMonotonicInitializer,\n              \"ScaleInitializer\":\n                  kfll.ScaleInitializer,\n              \"ScaleConstraints\":\n                  kfll.ScaleConstraints,\n              \"BiasInitializer\":\n                  kfll.BiasInitializer,\n          })\n    # Extract loaded layer.\n    loaded_keras_layer = loaded_model.layers[0]\n    # Extract scale from layer and constraint after load.\n    loaded_layer_scale = loaded_keras_layer.scale\n    loaded_constraint_scale = loaded_keras_layer.kernel.constraint.scale\n    self.assertIs(loaded_layer_scale, loaded_constraint_scale)\n\n  @parameterized.parameters(\n      (1, 3, 1),\n      (1, 3, 2),\n      (3, 7, 3),\n  )\n  def testOutputShapeForDifferentInputTypes(self, 
batch_size, dims, units):\n    expected_output_shape = (batch_size, units)\n    # Create KFL Layer instance.\n    kfl_layer = kfll.KroneckerFactoredLattice(lattice_sizes=2, units=units)\n    # Input (batch_size, dims) or (batch_size, units, dims)\n    if units == 1:\n      example = [float(i) for i in range(dims)]\n      examples = [example for _ in range(batch_size)]\n    else:\n      example = [[float(i) for i in range(dims)] for _ in range(units)]\n      examples = [example for _ in range(batch_size)]\n    inputs = tf.constant(examples)\n    outputs = kfl_layer(inputs)\n    self.assertEqual(outputs.shape, expected_output_shape)\n    # Input length-dims list of (batch_size, 1) or (batch_size, units, 1)\n    example = tf.constant(\n        [[float(i) if units == 1 else [float(i)]\n          for i in range(units)]\n         for _ in range(batch_size)])\n    list_inputs = [example for _ in range(dims)]\n    list_outputs = kfl_layer(list_inputs)\n    self.assertEqual(list_outputs.shape, expected_output_shape)\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_lattice/python/lattice_layer.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Lattice layer with monotonicity, unimodality, trust and bound constraints.\n\nKeras implementation of tensorflow lattice layer. This layer takes one or more\nd-dimensional input(s) and combines them using a lattice function, satisfying\nmonotonicity, unimodality, trust and bound constraints if specified.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport six\nimport tensorflow as tf\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\nfrom . import lattice_lib\nfrom . import utils\n\nLATTICE_KERNEL_NAME = \"lattice_kernel\"\nLATTICE_SIZES_NAME = \"lattice_sizes\"\n\n\nclass Lattice(keras.layers.Layer):\n  # pyformat: disable\n  \"\"\"Lattice layer.\n\n  Layer performs interpolation using one of `units` d-dimensional lattices with\n  arbitrary number of keypoints per dimension. There are trainable weights\n  associated with lattice vertices. Input to this layer is considered to be a\n  d-dimensional point within the lattice. If point coincides with one of the\n  lattice vertex then interpolation result for this point is equal to weight\n  associated with that vertex. Otherwise, all surrounding vertices contribute to\n  the interpolation result inversely proportional to the distance from them.\n\n  For example lattice sizes: [2, 3] produce following lattice:\n\n  ```\n  o---o---o\n  |   |   |\n  o---o---o\n  ```\n\n  First coordinate of input tensor must be within [0, 1], and the second within\n  [0, 2]. If coordinates are outside of this range they will be clipped into it.\n\n  There are several types of constraints on the shape of the learned function\n  that are either 1 or 2 dimensional:\n\n  ![Shape constraint visual example images](https://www.tensorflow.org/lattice/images/2D_shape_constraints_picture_color.png)\n\n  * **Monotonicity:** constrains the function to be increasing in the\n    corresponding dimension. To achieve decreasing monotonicity, either pass the\n    inputs through a `tfl.layers.PWLCalibration` with `decreasing` monotonicity,\n    or manually reverse the inputs as `lattice_size - 1 - inputs`.\n  * **Unimodality:** constrains the function to be unimodal in that dimension\n    with minimum being in the center lattice vertex of that dimension. Single\n    dimension can not be constrained to be both monotonic and unimodal.\n    Unimodal dimensions must have at least 3 lattice vertices.\n  * **Edgeworth Trust:** constrains the function to be more responsive to a main\n    feature as a secondary conditional feature increases or decreases. For\n    example, we may want the model to rely more on average rating (main\n    feature) when the number of reviews (conditional feature) is high. 
In\n    particular, the constraint guarantees that a given change in the main\n    feature's value will change the model output by more when a secondary\n    feature indicates higher trust in the main feature. Note that the\n    constraint only works when the model is monotonic in the main feature.\n  * **Trapezoid Trust:** conceptually similar to edgeworth trust, but this\n    constraint guarantees that the range of possible outputs along the main\n    feature dimension, when a conditional feature indicates low trust, is a\n    *subset* of the range of outputs when a conditional feature indicates high\n    trust. When lattices have 2 vertices in each constrained dimension, this\n    implies edgeworth trust (which only constrains the size of the relevant\n    ranges). With more than 2 lattice vertices per dimension, the two\n    constraints diverge and are not necessarily 'weaker' or 'stronger' than\n    each other - edgeworth trust acts throughout the lattice interior on delta\n    shifts in the main feature, while trapezoid trust only acts on the min and\n    max extremes of the main feature, constraining the overall range of\n    outputs across the domain of the main feature. The two types of trust\n    constraints can be applied jointly.\n  * **Monotonic Dominance:** constrains the function to require the effect\n    (slope) in the direction of the *dominant* dimension to be greater than that\n    of the *weak* dimension for any point in the lattice. Both dominant and weak\n    dimensions must be monotonic. Note that this constraint might not be\n    strictly satisfied at the end of training. In such cases, increase the\n    number of projection iterations.\n  * **Range Dominance:** constrains the function to require the range of\n    possible outputs to be greater when one varies the *dominant* dimension\n    than when one varies the *weak* dimension for any point. Both dominant and\n    weak dimensions must be monotonic. Note that this constraint might not be\n    strictly satisfied at the end of training. In such cases, increase the\n    number of projection iterations.\n  * **Joint Monotonicity:** constrains the function to be monotonic along a\n    diagonal direction of a two dimensional subspace when all other dimensions\n    are fixed. For example, if our function is scoring the profit given *A*\n    hotel guests and *B* hotel beds, it may be wrong to constrain the profit to\n    be increasing in either hotel guests or hotel beds independently, but along\n    the diagonal (+1 guest and +1 bed), the profit should be monotonic. Note\n    that this constraint might not be strictly satisfied at the end of\n    training. 
In such cases, increase the number of projection iterations.\n\n  There are upper and lower bound constraints on the output.\n\n  All units share the same layer configuration, but each has their separate set\n  of trained parameters.\n\n  Input shape:\n    - if `units == 1`: tensor of shape: `(batch_size, ..., len(lattice_sizes))`\n      or list of `len(lattice_sizes)` tensors of same shape:\n      `(batch_size, ..., 1)`\n    - if `units > 1`: tensor of shape:\n      `(batch_size, ..., units, len(lattice_sizes))` or list of\n      `len(lattice_sizes)` tensors of same shape: `(batch_size, ..., units, 1)`\n\n    A typical shape is: `(batch_size, len(lattice_sizes))`\n\n  Output shape:\n    Tensor of shape: `(batch_size, ..., units)`\n\n  Attributes:\n    - All `__init__` arguments.\n    kernel: weights of the lattice.\n\n  Example:\n\n  ```python\n  lattice = tfl.layers.Lattice(\n      # Number of vertices along each dimension.\n      lattice_sizes=[2, 2, 3, 4, 2, 2, 3],\n      # You can specify monotonicity constraints.\n      monotonicities=['increasing', 'none', 'increasing', 'increasing',\n                      'increasing', 'increasing', 'increasing'],\n      # You can specify trust constraints between pairs of features. Here we\n      # constrain the function to be more responsive to a main feature (index 4)\n      # as a secondary conditional feature (index 3) increases (positive\n      # direction).\n      edgeworth_trusts=(4, 3, 'positive'),\n      # Output can be bounded.\n      output_min=0.0,\n      output_max=1.0)\n  ```\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self,\n               lattice_sizes,\n               units=1,\n               monotonicities=None,\n               unimodalities=None,\n               edgeworth_trusts=None,\n               trapezoid_trusts=None,\n               monotonic_dominances=None,\n               range_dominances=None,\n               joint_monotonicities=None,\n               joint_unimodalities=None,\n               output_min=None,\n               output_max=None,\n               num_projection_iterations=10,\n               monotonic_at_every_step=True,\n               clip_inputs=True,\n               interpolation=\"hypercube\",\n               kernel_initializer=\"random_uniform_or_linear_initializer\",\n               kernel_regularizer=None,\n               **kwargs):\n    # pyformat: disable\n    \"\"\"Initializes an instance of `Lattice`.\n\n    Args:\n      lattice_sizes: List or tuple of length d of integers which represents\n        number of lattice vertices per dimension (minimum is 2). Second\n        dimension of input shape must match the number of elements in lattice\n        sizes.\n      units: Output dimension of the layer. 
See class comments for details.\n      monotonicities: None or list or tuple of same length as lattice_sizes of\n        {'none', 'increasing', 0, 1} which specifies if the model output should\n        be monotonic in corresponding feature, using 'increasing' or 1 to\n        indicate increasing monotonicity and 'none' or 0 to indicate no\n        monotonicity constraints.\n      unimodalities: None or list or tuple of same length as lattice_sizes of\n        {'none', 'valley', 'peak', 0, 1, -1} which specifies if the model output\n        should be unimodal in corresponding feature, using 'valley' or 1 to\n        indicate that function first decreases then increases, using 'peak' or\n        -1 to indicate that function first increases then decreases, using\n        'none' or 0 to indicate no unimodality constraints.\n      edgeworth_trusts: None or three-element tuple or iterable of three-element\n        tuples. First element is the index of the main (monotonic) feature.\n        Second element is the index of the conditional feature. Third element is\n        the direction of trust: 'positive' or 1 if higher values of the\n        conditional feature should increase trust in the main feature and\n        'negative' or -1 otherwise.\n      trapezoid_trusts: None or three-element tuple or iterable of three-element\n        tuples. First element is the index of the main (monotonic) feature.\n        Second element is the index of the conditional feature. Third element is\n        the direction of trust: 'positive' or 1 if higher values of the\n        conditional feature should increase trust in the main feature and\n        'negative' or -1 otherwise.\n      monotonic_dominances: None or two-element tuple or iterable of two-element\n        tuples. First element is the index of the dominant feature. Second\n        element is the index of the weak feature.\n      range_dominances: None or two-element tuple or iterable of two-element\n        tuples. First element is the index of the dominant feature. Second\n        element is the index of the weak feature.\n      joint_monotonicities: None or two-element tuple or iterable of two-element\n        tuples which represent indices of two features requiring joint\n        monotonicity.\n      joint_unimodalities: None or tuple or iterable of tuples. Each tuple\n        contains 2 elements: iterable of indices of single group of jointly\n        unimodal features followed by string 'valley' or 'peak', using 'valley'\n        to indicate that function first decreases then increases, using 'peak'\n        to indicate that function first increases then decreases. For example:\n        ([0, 3, 4], 'valley').\n      output_min: None or lower bound of the output.\n      output_max: None or upper bound of the output.\n      num_projection_iterations: Number of iterations of Dykstra projections\n        algorithm. Projection updates will be closer to a true projection (with\n        respect to the L2 norm) with higher number of iterations. Increasing\n        this number has diminishing returns on projection precision. Infinite\n        number of iterations would yield perfect projection. Increasing this\n        number might slightly improve convergence at the cost of slightly\n        increasing running time. 
Most likely you want this number to be proportional to\n        number of lattice vertices in largest constrained dimension.\n      monotonic_at_every_step: Whether to strictly enforce monotonicity and\n        trust constraints after every gradient update by applying a final\n        imprecise projection. Setting this parameter to True together with small\n        num_projection_iterations parameter is likely to hurt convergence.\n      clip_inputs: If inputs should be clipped to the input range of the\n        lattice.\n      interpolation: One of 'hypercube' or 'simplex' interpolation. For a\n        d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas\n        'simplex' uses d+1 parameters and thus scales better. For details see\n        `tfl.lattice_lib.evaluate_with_simplex_interpolation` and\n        `tfl.lattice_lib.evaluate_with_hypercube_interpolation`.\n      kernel_initializer: None or one of:\n        - `'linear_initializer'`: initialize parameters to form a linear\n          function with positive and equal coefficients for monotonic dimensions\n          and 0.0 coefficients for other dimensions. Linear function is such\n          that minimum possible output is equal to output_min and maximum\n          possible output is equal to output_max. See\n          `tfl.lattice_layer.LinearInitializer` class docstring for more\n          details.\n        - `'random_monotonic_initializer'`: initialize parameters uniformly at\n          random such that all parameters are monotonically increasing for each\n          input. Parameters will be sampled uniformly at random from the range\n          `[output_min, output_max]`. See\n          `tfl.lattice_layer.RandomMonotonicInitializer` class docstring for\n          more details.\n        - `random_uniform_or_linear_initializer`: if the lattice has a single\n          joint unimodality constraint group encompassing all features then use\n          the Keras 'random_uniform' initializer; otherwise, use TFL's\n          'linear_initializer'.\n        - Any Keras initializer object.\n      kernel_regularizer: None or a single element or a list of following:\n        - Tuple `('torsion', l1, l2)` where l1 and l2 represent corresponding\n          regularization amount for graph Torsion regularizer. l1 and l2 can\n          either be single floats or lists of floats to specify different\n          regularization amount for every dimension.\n        - Tuple `('laplacian', l1, l2)` where l1 and l2 represent corresponding\n          regularization amount for graph Laplacian regularizer. 
l1 and l2 can\n          either be single floats or lists of floats to specify different\n          regularization amount for every dimension.\n        - Any Keras regularizer object.\n      **kwargs: Other args passed to `keras.layers.Layer` initializer.\n\n    Raises:\n      ValueError: If layer hyperparameters are invalid.\n    \"\"\"\n    # pyformat: enable\n    lattice_lib.verify_hyperparameters(\n        lattice_sizes=lattice_sizes,\n        monotonicities=monotonicities,\n        unimodalities=unimodalities,\n        interpolation=interpolation)\n    super(Lattice, self).__init__(**kwargs)\n\n    self.lattice_sizes = lattice_sizes\n    self.units = units\n    self.monotonicities = monotonicities\n    self.unimodalities = unimodalities\n    # Check if inputs are a single tuple of ints (vs an iterable of tuples)\n    if (isinstance(edgeworth_trusts, tuple) and\n        isinstance(edgeworth_trusts[0], int)):\n      self.edgeworth_trusts = [edgeworth_trusts]\n    else:\n      self.edgeworth_trusts = edgeworth_trusts\n    if (isinstance(trapezoid_trusts, tuple) and\n        isinstance(trapezoid_trusts[0], int)):\n      self.trapezoid_trusts = [trapezoid_trusts]\n    else:\n      self.trapezoid_trusts = trapezoid_trusts\n    if (isinstance(monotonic_dominances, tuple) and\n        isinstance(monotonic_dominances[0], int)):\n      self.monotonic_dominances = [monotonic_dominances]\n    else:\n      self.monotonic_dominances = monotonic_dominances\n    if (isinstance(range_dominances, tuple) and\n        isinstance(range_dominances[0], int)):\n      self.range_dominances = [range_dominances]\n    else:\n      self.range_dominances = range_dominances\n    if (isinstance(joint_monotonicities, tuple) and\n        isinstance(joint_monotonicities[0], int)):\n      self.joint_monotonicities = [joint_monotonicities]\n    else:\n      self.joint_monotonicities = joint_monotonicities\n    if (isinstance(joint_unimodalities, tuple) and\n        len(joint_unimodalities) == 2 and\n        isinstance(joint_unimodalities[1], six.string_types)):\n      self.joint_unimodalities = [joint_unimodalities]\n    else:\n      self.joint_unimodalities = joint_unimodalities\n    self.output_min = output_min\n    self.output_max = output_max\n    self.num_projection_iterations = num_projection_iterations\n    self.monotonic_at_every_step = monotonic_at_every_step\n    self.clip_inputs = clip_inputs\n    self.interpolation = interpolation\n\n    self.kernel_initializer = create_kernel_initializer(\n        kernel_initializer, self.lattice_sizes, self.monotonicities,\n        self.output_min, self.output_max, self.unimodalities,\n        self.joint_unimodalities)\n\n    self.kernel_regularizer = []\n    if kernel_regularizer:\n      if (callable(kernel_regularizer) or\n          (isinstance(kernel_regularizer, tuple) and\n           isinstance(kernel_regularizer[0], six.string_types))):\n        kernel_regularizer = [kernel_regularizer]\n\n      for regularizer in kernel_regularizer:\n        if isinstance(regularizer, tuple):\n          (name, l1, l2) = regularizer\n          if name.lower() == \"torsion\":\n            self.kernel_regularizer.append(\n                TorsionRegularizer(\n                    lattice_sizes=self.lattice_sizes, l1=l1, l2=l2))\n          elif name.lower() == \"laplacian\":\n            self.kernel_regularizer.append(\n                LaplacianRegularizer(\n                    lattice_sizes=self.lattice_sizes, l1=l1, l2=l2))\n          else:\n            raise ValueError(\"Unknown custom 
lattice regularizer: %s\" %\n                             regularizer)\n        else:\n          # This is needed for Keras deserialization logic to be aware of our\n          # custom objects.\n          with keras.utils.custom_object_scope({\n              \"TorsionRegularizer\": TorsionRegularizer,\n              \"LaplacianRegularizer\": LaplacianRegularizer,\n          }):\n            self.kernel_regularizer.append(keras.regularizers.get(regularizer))\n\n  def build(self, input_shape):\n    \"\"\"Standard Keras build() method.\"\"\"\n    lattice_lib.verify_hyperparameters(\n        lattice_sizes=self.lattice_sizes,\n        units=self.units,\n        input_shape=input_shape)\n    constraints = LatticeConstraints(\n        lattice_sizes=self.lattice_sizes,\n        monotonicities=self.monotonicities,\n        unimodalities=self.unimodalities,\n        edgeworth_trusts=self.edgeworth_trusts,\n        trapezoid_trusts=self.trapezoid_trusts,\n        monotonic_dominances=self.monotonic_dominances,\n        range_dominances=self.range_dominances,\n        joint_monotonicities=self.joint_monotonicities,\n        joint_unimodalities=self.joint_unimodalities,\n        output_min=self.output_min,\n        output_max=self.output_max,\n        num_projection_iterations=self.num_projection_iterations,\n        enforce_strict_monotonicity=self.monotonic_at_every_step)\n\n    if not self.kernel_regularizer:\n      kernel_reg = None\n    elif len(self.kernel_regularizer) == 1:\n      kernel_reg = self.kernel_regularizer[0]\n    else:\n      # Keras interface assumes only one regularizer, so sum all of the\n      # regularization losses that we have.\n      kernel_reg = lambda x: tf.add_n([r(x) for r in self.kernel_regularizer])\n\n    num_weights = 1\n    for dim_size in self.lattice_sizes:\n      num_weights *= dim_size\n    self.kernel = self.add_weight(\n        LATTICE_KERNEL_NAME,\n        shape=[num_weights, self.units],\n        initializer=self.kernel_initializer,\n        regularizer=kernel_reg,\n        constraint=constraints,\n        dtype=self.dtype)\n\n    if self.kernel_regularizer and not tf.executing_eagerly():\n      # Keras has its own mechanism to handle regularization losses which does\n      # not use GraphKeys, but we want to also add losses to graph keys so they\n      # are easily accessible when the layer is used outside of Keras. Adding\n      # losses to GraphKeys will not interfere with Keras.\n      for reg in self.kernel_regularizer:\n        tf.compat.v1.add_to_collection(\n            tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, reg(self.kernel))\n\n    # Constraints with enforce_strict_monotonicity always set to True. 
Intended\n    # to be run at the end of training or any time when you need everything to\n    # be strictly projected.\n    self._final_constraints = LatticeConstraints(\n        lattice_sizes=self.lattice_sizes,\n        monotonicities=self.monotonicities,\n        unimodalities=self.unimodalities,\n        edgeworth_trusts=self.edgeworth_trusts,\n        trapezoid_trusts=self.trapezoid_trusts,\n        monotonic_dominances=self.monotonic_dominances,\n        range_dominances=self.range_dominances,\n        joint_monotonicities=self.joint_monotonicities,\n        joint_unimodalities=self.joint_unimodalities,\n        output_min=self.output_min,\n        output_max=self.output_max,\n        num_projection_iterations=20,\n        enforce_strict_monotonicity=True)\n\n    self.lattice_sizes_tensor = tf.constant(\n        self.lattice_sizes, dtype=tf.int32, name=LATTICE_SIZES_NAME)\n    super(Lattice, self).build(input_shape)\n\n  def call(self, inputs):\n    \"\"\"Standard Keras call() method.\"\"\"\n    # Use control dependencies to save lattice sizes as graph constant for\n    # visualisation toolbox to be able to recover it from saved graph.\n    # Wrap this constant into pure op since in TF 2.0 there are issues passing\n    # tensors into control_dependencies.\n    with tf.control_dependencies([tf.identity(self.lattice_sizes_tensor)]):\n      if self.interpolation == \"simplex\":\n        return lattice_lib.evaluate_with_simplex_interpolation(\n            inputs=inputs,\n            kernel=self.kernel,\n            units=self.units,\n            lattice_sizes=self.lattice_sizes,\n            clip_inputs=self.clip_inputs)\n      elif self.interpolation == \"hypercube\":\n        return lattice_lib.evaluate_with_hypercube_interpolation(\n            inputs=inputs,\n            kernel=self.kernel,\n            units=self.units,\n            lattice_sizes=self.lattice_sizes,\n            clip_inputs=self.clip_inputs)\n      else:\n        raise ValueError(\"Unknown interpolation type: %s\" % self.interpolation)\n\n  def compute_output_shape(self, input_shape):\n    \"\"\"Standard Keras compute_output_shape() method.\"\"\"\n    if isinstance(input_shape, list):\n      input_shape = input_shape[0]\n    if self.units == 1:\n      return tuple(input_shape[:-1]) + (1,)\n    else:\n      # Second to last dimension must be equal to 'units'. 
Nothing to append.\n      return input_shape[:-1]\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    config = {\n        \"lattice_sizes\": self.lattice_sizes,\n        \"units\": self.units,\n        \"monotonicities\": self.monotonicities,\n        \"unimodalities\": self.unimodalities,\n        \"edgeworth_trusts\": self.edgeworth_trusts,\n        \"trapezoid_trusts\": self.trapezoid_trusts,\n        \"monotonic_dominances\": self.monotonic_dominances,\n        \"range_dominances\": self.range_dominances,\n        \"joint_monotonicities\": self.joint_monotonicities,\n        \"joint_unimodalities\": self.joint_unimodalities,\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n        \"num_projection_iterations\": self.num_projection_iterations,\n        \"monotonic_at_every_step\": self.monotonic_at_every_step,\n        \"clip_inputs\": self.clip_inputs,\n        \"interpolation\": self.interpolation,\n        \"kernel_initializer\":\n            keras.initializers.serialize(\n                self.kernel_initializer, use_legacy_format=True),\n        \"kernel_regularizer\":\n            [keras.regularizers.serialize(r, use_legacy_format=True)\n             for r in self.kernel_regularizer],\n    }  # pyformat: disable\n    config.update(super(Lattice, self).get_config())\n    return config\n\n  def finalize_constraints(self):\n    \"\"\"Ensures layers weights strictly satisfy constraints.\n\n    Applies approximate projection to strictly satisfy specified constraints.\n    If `monotonic_at_every_step == True` there is no need to call this function.\n\n    Returns:\n      In eager mode directly updates weights and returns variable which stores\n      them. In graph mode returns `assign_add` op which has to be executed to\n      updates weights.\n    \"\"\"\n    return self.kernel.assign_add(\n        self._final_constraints(self.kernel) - self.kernel)\n\n  def assert_constraints(self, eps=1e-6):\n    \"\"\"Asserts that weights satisfy all constraints.\n\n    In graph mode builds and returns list of assertion ops.\n    In eager mode directly executes assertions.\n\n    Args:\n      eps: allowed constraints violation.\n\n    Returns:\n      List of assertion ops in graph mode or immediately asserts in eager mode.\n    \"\"\"\n    return lattice_lib.assert_constraints(\n        weights=self.kernel,\n        lattice_sizes=self.lattice_sizes,\n        monotonicities=utils.canonicalize_monotonicities(\n            self.monotonicities, allow_decreasing=False),\n        edgeworth_trusts=utils.canonicalize_trust(self.edgeworth_trusts),\n        trapezoid_trusts=utils.canonicalize_trust(self.trapezoid_trusts),\n        monotonic_dominances=self.monotonic_dominances,\n        range_dominances=self.range_dominances,\n        joint_monotonicities=self.joint_monotonicities,\n        joint_unimodalities=self.joint_unimodalities,\n        output_min=self.output_min,\n        output_max=self.output_max,\n        eps=eps)\n\n\ndef create_kernel_initializer(kernel_initializer_id,\n                              lattice_sizes,\n                              monotonicities,\n                              output_min,\n                              output_max,\n                              unimodalities,\n                              joint_unimodalities,\n                              init_min=None,\n                              init_max=None):\n  \"\"\"Returns a kernel Keras initializer object from its id.\n\n  This function is used to 
convert the 'kernel_initializer' parameter in the\n  constructor of tfl.Lattice into the corresponding initializer object.\n\n  Args:\n    kernel_initializer_id: See the documentation of the 'kernel_initializer'\n      parameter in the constructor of tfl.Lattice.\n    lattice_sizes: See the documentation of the same parameter in the\n      constructor of tfl.Lattice.\n    monotonicities: See the documentation of the same parameter in the\n      constructor of tfl.Lattice.\n    output_min: See the documentation of the same parameter in the constructor\n      of tfl.Lattice.\n    output_max: See the documentation of the same parameter in the constructor\n      of tfl.Lattice.\n    unimodalities: See the documentation of the same parameter in the\n      constructor of tfl.Lattice.\n    joint_unimodalities: See the documentation of the same parameter in the\n      constructor of tfl.Lattice.\n    init_min: None or lower bound of kernel initialization. If set, init_max\n      must also be set.\n    init_max: None or upper bound of kernel initialization. If set, init_min\n      must also be set.\n\n  Returns:\n    The Keras initializer object for the tfl.Lattice kernel variable.\n\n  Raises:\n    ValueError: If only one of init_{min/max} is set.\n  \"\"\"\n  if ((init_min is not None and init_max is None) or\n      (init_min is None and init_max is not None)):\n    raise ValueError(\"Both or neither of init_{min/max} must be set\")\n\n  def do_joint_unimodalities_contain_all_features(joint_unimodalities):\n    if (joint_unimodalities is None) or (len(joint_unimodalities) != 1):\n      return False\n    [joint_unimodalities] = joint_unimodalities\n    return set(joint_unimodalities[0]) == set(range(len(lattice_sizes)))\n\n  # Initialize joint unimodalities identical to regular ones.\n  all_unimodalities = [0] * len(lattice_sizes)\n  if unimodalities:\n    for i, value in enumerate(unimodalities):\n      if value:\n        all_unimodalities[i] = value\n  if joint_unimodalities:\n    for dimensions, direction in joint_unimodalities:\n      for dim in dimensions:\n        all_unimodalities[dim] = direction\n\n  if kernel_initializer_id in [\"linear_initializer\", \"LinearInitializer\"]:\n    if init_min is None and init_max is None:\n      init_min, init_max = lattice_lib.default_init_params(\n          output_min, output_max)\n\n    return LinearInitializer(\n        lattice_sizes=lattice_sizes,\n        monotonicities=monotonicities,\n        output_min=init_min,\n        output_max=init_max,\n        unimodalities=all_unimodalities)\n  elif kernel_initializer_id in [\n      \"random_monotonic_initializer\", \"RandomMonotonicInitializer\"\n  ]:\n    if init_min is None and init_max is None:\n      init_min, init_max = lattice_lib.default_init_params(\n          output_min, output_max)\n\n    return RandomMonotonicInitializer(\n        lattice_sizes=lattice_sizes,\n        output_min=init_min,\n        output_max=init_max,\n        unimodalities=all_unimodalities)\n  elif kernel_initializer_id in [\n      \"random_uniform_or_linear_initializer\", \"RandomUniformOrLinearInitializer\"\n  ]:\n    if do_joint_unimodalities_contain_all_features(joint_unimodalities):\n      return create_kernel_initializer(\"random_uniform\", lattice_sizes,\n                                       monotonicities, output_min, output_max,\n                                       unimodalities, joint_unimodalities,\n                                       init_min, init_max)\n    return 
create_kernel_initializer(\"linear_initializer\", lattice_sizes,\n                                     monotonicities, output_min, output_max,\n                                     unimodalities, joint_unimodalities,\n                                     init_min, init_max)\n  else:\n    # This is needed for Keras deserialization logic to be aware of our custom\n    # objects.\n    with keras.utils.custom_object_scope({\n        \"LinearInitializer\": LinearInitializer,\n        \"RandomMonotonicInitializer\": RandomMonotonicInitializer,\n    }):\n      return keras.initializers.get(kernel_initializer_id)\n\n\nclass LinearInitializer(keras.initializers.Initializer):\n  # pyformat: disable\n  \"\"\"Initializes a `tfl.layers.Lattice` as linear function.\n\n  - The linear function will have positive coefficients for monotonic dimensions\n    and 0 otherwise. If all dimensions are unconstrained, all coefficients will\n    be positive.\n  - Linear coefficients are set such that the minimum/maximum output of the\n    lattice matches the given output_min/output_max.\n  - Each monotonic dimension contributes with same weight regardless of number\n    of vertices per dimension.\n  - No dimension can be both monotonic and unimodal.\n  - Unimodal dimensions contribute with same weight as monotonic dimensions.\n  - Unimodal dimensions linearly decrease for first `(dim_size + 1) // 2`\n    vertices and then linearly increase for following vertices.\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self,\n               lattice_sizes,\n               monotonicities,\n               output_min,\n               output_max,\n               unimodalities=None):\n    \"\"\"Initializes an instance of `LinearInitializer`.\n\n    Args:\n      lattice_sizes: Lattice sizes of `tfl.layers.Lattice` to initialize.\n      monotonicities: Monotonic dimensions for initialization. Does not need to\n        match `monotonicities` of `tfl.layers.Lattice`.\n      output_min: Minimum layer output after initialization.\n      output_max: Maximum layer output after initialization.\n      unimodalities: None or unimodal dimensions after initialization. Does not\n        need to match `unimodalities` of `tfl.layers.Lattice`.\n\n    Raises:\n      ValueError: If there is a mismatch between `monotonicities` and\n      `lattice_sizes`.\n    \"\"\"\n    lattice_lib.verify_hyperparameters(\n        lattice_sizes=lattice_sizes,\n        monotonicities=monotonicities,\n        unimodalities=unimodalities,\n        output_min=output_min,\n        output_max=output_max)\n\n    self.lattice_sizes = lattice_sizes\n    self.monotonicities = monotonicities\n    self.output_min = output_min\n    self.output_max = output_max\n    self.unimodalities = unimodalities\n\n  def __call__(self, shape, dtype=None, partition_info=None):\n    \"\"\"Returns weights of `tfl.layers.Lattice` layer.\n\n    Args:\n      shape: Must be: `(prod(lattice_sizes), units)`.\n      dtype: Standard Keras initializer param.\n      partition_info: Standard Keras initializer param. 
Not used.\n    \"\"\"\n    # TODO: figure out whether it should be used.\n    del partition_info\n    return lattice_lib.linear_initializer(\n        lattice_sizes=self.lattice_sizes,\n        monotonicities=utils.canonicalize_monotonicities(\n            self.monotonicities, allow_decreasing=False),\n        unimodalities=utils.canonicalize_unimodalities(self.unimodalities),\n        output_min=self.output_min,\n        output_max=self.output_max,\n        units=shape[1],\n        dtype=dtype)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    config = {\n        \"lattice_sizes\": self.lattice_sizes,\n        \"monotonicities\": self.monotonicities,\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n        \"unimodalities\": self.unimodalities,\n    }  # pyformat: disable\n    return config\n\n\nclass RandomMonotonicInitializer(keras.initializers.Initializer):\n  # pyformat: disable\n  \"\"\"Initializes a `tfl.layers.Lattice` as uniform random monotonic function.\n\n  - The uniform random monotonic function will initilaize the lattice parameters\n    uniformly at random and make it such that the parameters are monotonically\n    increasing for each input.\n  - The random parameters will be sampled from `[output_min, output_max]`\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self, lattice_sizes, output_min, output_max, unimodalities=None):\n    \"\"\"Initializes an instance of `RandomMonotonicInitializer`.\n\n    Args:\n      lattice_sizes: Lattice sizes of `tfl.layers.Lattice` to initialize.\n      output_min: Minimum layer output after initialization.\n      output_max: Maximum layer output after initialization.\n      unimodalities: None or unimodal dimensions after initialization. Does not\n        need to match `unimodalities` of `tfl.layers.Lattice`.\n\n    Raises:\n      ValueError: If there are invalid hyperparameters.\n    \"\"\"\n    lattice_lib.verify_hyperparameters(\n        lattice_sizes=lattice_sizes,\n        unimodalities=unimodalities,\n        output_min=output_min,\n        output_max=output_max)\n\n    self.lattice_sizes = lattice_sizes\n    self.output_min = output_min\n    self.output_max = output_max\n    self.unimodalities = unimodalities\n\n  def __call__(self, shape, dtype=None, partition_info=None):\n    \"\"\"Returns weights of `tfl.layers.Lattice` layer.\n\n    Args:\n      shape: Must be: `(prod(lattice_sizes), units)`.\n      dtype: Standard Keras initializer param.\n      partition_info: Standard Keras initializer param. Not used.\n    \"\"\"\n    del partition_info\n    return lattice_lib.random_monotonic_initializer(\n        lattice_sizes=self.lattice_sizes,\n        output_min=self.output_min,\n        output_max=self.output_max,\n        units=shape[1],\n        dtype=dtype)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    config = {\n        \"lattice_sizes\": self.lattice_sizes,\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n        \"unimodalities\": self.unimodalities,\n    }  # pyformat: disable\n    return config\n\n\nclass LatticeConstraints(keras.constraints.Constraint):\n  # pyformat: disable\n  \"\"\"Constraints for `tfl.layers.Lattice` layer.\n\n  Applies all constraints to the lattice weights. 
See `tfl.layers.Lattice`\n  for more details.\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self,\n               lattice_sizes,\n               monotonicities=None,\n               unimodalities=None,\n               edgeworth_trusts=None,\n               trapezoid_trusts=None,\n               monotonic_dominances=None,\n               range_dominances=None,\n               joint_monotonicities=None,\n               joint_unimodalities=None,\n               output_min=None,\n               output_max=None,\n               num_projection_iterations=1,\n               enforce_strict_monotonicity=True):\n    \"\"\"Initializes an instance of `LatticeConstraints`.\n\n    Args:\n      lattice_sizes: Lattice sizes of `Lattice` layer to constrain.\n      monotonicities: Same meaning as corresponding parameter of `Lattice`.\n      unimodalities: Same meaning as corresponding parameter of `Lattice`.\n      edgeworth_trusts: Same meaning as corresponding parameter of `Lattice`.\n      trapezoid_trusts: Same meaning as corresponding parameter of `Lattice`.\n      monotonic_dominances: Same meaning as corresponding parameter of\n        `Lattice`.\n      range_dominances: Same meaning as corresponding parameter of `Lattice`.\n      joint_monotonicities: Same meaning as corresponding parameter of\n        `Lattice`.\n      joint_unimodalities: Same meaning as corresponding parameter of `Lattice`.\n      output_min: Minimum possible output.\n      output_max: Maximum possible output.\n      num_projection_iterations: Same meaning as corresponding parameter of\n        `Lattice`.\n      enforce_strict_monotonicity: Whether to use approximate projection to\n        ensure that constraints are strictly satisfied.\n\n    Raises:\n      ValueError: If weights to project don't correspond to `lattice_sizes`.\n    \"\"\"\n    lattice_lib.verify_hyperparameters(\n        lattice_sizes=lattice_sizes,\n        monotonicities=monotonicities,\n        unimodalities=unimodalities,\n        edgeworth_trusts=edgeworth_trusts,\n        trapezoid_trusts=trapezoid_trusts,\n        monotonic_dominances=monotonic_dominances,\n        range_dominances=range_dominances,\n        joint_monotonicities=joint_monotonicities,\n        joint_unimodalities=joint_unimodalities)\n\n    self.lattice_sizes = lattice_sizes\n    self.monotonicities = utils.canonicalize_monotonicities(\n        monotonicities, allow_decreasing=False)\n    self.unimodalities = utils.canonicalize_unimodalities(unimodalities)\n    self.edgeworth_trusts = utils.canonicalize_trust(edgeworth_trusts)\n    self.trapezoid_trusts = utils.canonicalize_trust(trapezoid_trusts)\n    self.monotonic_dominances = monotonic_dominances\n    self.range_dominances = range_dominances\n    self.joint_monotonicities = joint_monotonicities\n    self.joint_unimodalities = joint_unimodalities\n    self.output_min = output_min\n    self.output_max = output_max\n    self.num_projection_iterations = num_projection_iterations\n    self.enforce_strict_monotonicity = enforce_strict_monotonicity\n    self.num_constraint_dims = utils.count_non_zeros(self.monotonicities,\n                                                     self.unimodalities)\n\n  def __call__(self, w):\n    \"\"\"Applies constraints to `w`.\"\"\"\n    # No need to separately check for trust constraints and monotonic dominance,\n    # since monotonicity is required to impose them. 
The only exception is joint\n    # monotonicity.\n    if (self.num_constraint_dims > 0 or self.joint_monotonicities or\n        self.joint_unimodalities):\n      w = lattice_lib.project_by_dykstra(\n          w,\n          lattice_sizes=self.lattice_sizes,\n          monotonicities=self.monotonicities,\n          unimodalities=self.unimodalities,\n          edgeworth_trusts=self.edgeworth_trusts,\n          trapezoid_trusts=self.trapezoid_trusts,\n          monotonic_dominances=self.monotonic_dominances,\n          range_dominances=self.range_dominances,\n          joint_monotonicities=self.joint_monotonicities,\n          joint_unimodalities=self.joint_unimodalities,\n          num_iterations=self.num_projection_iterations)\n      if self.enforce_strict_monotonicity:\n        w = lattice_lib.finalize_constraints(\n            w,\n            lattice_sizes=self.lattice_sizes,\n            monotonicities=self.monotonicities,\n            edgeworth_trusts=self.edgeworth_trusts,\n            trapezoid_trusts=self.trapezoid_trusts,\n            output_min=self.output_min,\n            output_max=self.output_max)\n    # TODO: come up with a better solution than separately applying\n    # bounds again after other projections.\n    if self.output_min is not None:\n      w = tf.maximum(w, self.output_min)\n    if self.output_max is not None:\n      w = tf.minimum(w, self.output_max)\n    return w\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"lattice_sizes\": self.lattice_sizes,\n        \"monotonicities\": self.monotonicities,\n        \"unimodalities\": self.unimodalities,\n        \"edgeworth_trusts\": self.edgeworth_trusts,\n        \"trapezoid_trusts\": self.trapezoid_trusts,\n        \"monotonic_dominances\": self.monotonic_dominances,\n        \"range_dominances\": self.range_dominances,\n        \"joint_monotonicities\": self.joint_monotonicities,\n        \"joint_unimodalities\": self.joint_unimodalities,\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n        \"num_projection_iterations\": self.num_projection_iterations,\n        \"enforce_strict_monotonicity\": self.enforce_strict_monotonicity\n    }  # pyformat: disable\n\n\nclass TorsionRegularizer(keras.regularizers.Regularizer):\n  # pyformat: disable\n  \"\"\"Torsion regularizer for `tfl.layers.Lattice` layer.\n\n  Lattice torsion regularizer penalizes how much the lattice function twists\n  from side-to-side (see\n  [publication](http://jmlr.org/papers/v17/15-243.html)).\n\n  Consider a 3 x 2 lattice with weights `w`:\n\n  ```\n  w[3]-----w[4]-----w[5]\n    |        |        |\n    |        |        |\n  w[0]-----w[1]-----w[2]\n  ```\n\n  In this case, the torsion regularizer is defined as:\n\n  ```\n  l1 * (|w[4] + w[0] - w[3] - w[1]| + |w[5] + w[1] - w[4] - w[2]|) +\n  l2 * ((w[4] + w[0] - w[3] - w[1])^2 + (w[5] + w[1] - w[4] - w[2])^2)\n  ```\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self, lattice_sizes, l1=0.0, l2=0.0):\n    \"\"\"Initializes an instance of `TorsionRegularizer`.\n\n    Args:\n      lattice_sizes: Lattice sizes of `tfl.layers.Lattice` to regularize.\n      l1: l1 regularization amount. Either single float or list or tuple of\n        floats to specify different regularization amount per dimension. 
The\n        amount of regularization for the interaction term between two dimensions\n        is the product of the corresponding per dimension amounts.\n      l2: l2 regularization amount. Either single float or list or tuple of\n        floats to specify different regularization amount per dimension. The\n        amount of regularization for the interaction term between two dimensions\n        is the product of the corresponding per dimension amounts.\n    \"\"\"\n    self.lattice_sizes = lattice_sizes\n    self.l1 = l1\n    self.l2 = l2\n\n  def __call__(self, x):\n    \"\"\"Returns regularization loss for `x`.\"\"\"\n    lattice_lib.verify_hyperparameters(\n        lattice_sizes=self.lattice_sizes, weights_shape=x.shape)\n    return lattice_lib.torsion_regularizer(x, self.lattice_sizes, self.l1,\n                                           self.l2)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"lattice_sizes\": self.lattice_sizes,\n        \"l1\": self.l1,\n        \"l2\": self.l2,\n    }  # pyformat: disable\n\n\nclass LaplacianRegularizer(keras.regularizers.Regularizer):\n  # pyformat: disable\n  \"\"\"Laplacian regularizer for `tfl.layers.Lattice` layer.\n\n  Laplacian regularizer penalizes the difference between adjacent vertices in\n  multi-cell lattice (see\n  [publication](http://jmlr.org/papers/v17/15-243.html)).\n\n  Consider a 3 x 2 lattice with weights `w`:\n\n  ```\n  w[3]-----w[4]-----w[5]\n    |        |        |\n    |        |        |\n  w[0]-----w[1]-----w[2]\n  ```\n\n  where the number at each node represents the weight index.\n  In this case, the laplacian regularizer is defined as:\n\n  ```\n  l1[0] * (|w[1] - w[0]| + |w[2] - w[1]| +\n           |w[4] - w[3]| + |w[5] - w[4]|) +\n  l1[1] * (|w[3] - w[0]| + |w[4] - w[1]| + |w[5] - w[2]|) +\n\n  l2[0] * ((w[1] - w[0])^2 + (w[2] - w[1])^2 +\n           (w[4] - w[3])^2 + (w[5] - w[4])^2) +\n  l2[1] * ((w[3] - w[0])^2 + (w[4] - w[1])^2 + (w[5] - w[2])^2)\n  ```\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self, lattice_sizes, l1=0.0, l2=0.0):\n    \"\"\"Initializes an instance of `LaplacianRegularizer`.\n\n    Args:\n      lattice_sizes: Lattice sizes of `tfl.layers.Lattice` to regularize.\n      l1: l1 regularization amount. Either single float or list or tuple of\n        floats to specify different regularization amount per dimension.\n      l2: l2 regularization amount. 
Either single float or list or tuple of\n        floats to specify different regularization amount per dimension.\n\n    Raises:\n      ValueError: If provided input does not correspond to `lattice_sizes`.\n    \"\"\"\n    lattice_lib.verify_hyperparameters(\n        lattice_sizes=lattice_sizes,\n        regularization_amount=l1,\n        regularization_info=\"l1\")\n    lattice_lib.verify_hyperparameters(\n        lattice_sizes=lattice_sizes,\n        regularization_amount=l2,\n        regularization_info=\"l2\")\n    self.lattice_sizes = lattice_sizes\n    self.l1 = l1\n    self.l2 = l2\n\n  def __call__(self, x):\n    \"\"\"Returns regularization loss for `x`.\"\"\"\n    lattice_lib.verify_hyperparameters(\n        lattice_sizes=self.lattice_sizes, weights_shape=x.shape)\n    return lattice_lib.laplacian_regularizer(x, self.lattice_sizes, self.l1,\n                                             self.l2)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"lattice_sizes\": self.lattice_sizes,\n        \"l1\": self.l1,\n        \"l2\": self.l2\n    }  # pyformat: disable\n"
  },
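  {
    "path": "examples/lattice_regularizers_sketch.py",
    "content": "# Editorial usage sketch (not part of the original TensorFlow Lattice\n# sources): a minimal, hedged example of applying the TorsionRegularizer and\n# LaplacianRegularizer defined alongside the Lattice layer directly to a\n# flattened lattice weight tensor of shape (prod(lattice_sizes), units).\n# The import path, lattice sizes and weight values below are illustrative\n# assumptions only.\nimport tensorflow as tf\n\nfrom tensorflow_lattice.python import lattice_layer  # assumed import path\n\n# A lattice with lattice_sizes=[2, 3] stores 2 * 3 = 6 weights, flattened to\n# shape (prod(lattice_sizes), units).\nweights = tf.constant([[0.0], [0.4], [1.0], [0.1], [0.6], [1.3]])\n\ntorsion = lattice_layer.TorsionRegularizer(\n    lattice_sizes=[2, 3], l1=0.1, l2=0.01)\nlaplacian = lattice_layer.LaplacianRegularizer(lattice_sizes=[2, 3], l1=0.1)\n\n# Each regularizer returns a scalar penalty for the given weight tensor.\nprint('torsion penalty:', float(torsion(weights)))\nprint('laplacian penalty:', float(laplacian(weights)))\n"
  },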
  {
    "path": "tensorflow_lattice/python/lattice_lib.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Implementation of algorithms required for Lattice layer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport copy\nimport itertools\nimport math\n\nfrom . import utils\nfrom absl import logging\nimport numpy as np\nimport six\nimport tensorflow as tf\n\n\ndef evaluate_with_simplex_interpolation(inputs, kernel, units, lattice_sizes,\n                                        clip_inputs):\n  \"\"\"Evaluates a lattice using simplex interpolation.\n\n  Within each cell of the lattice, we partition the hypercube into d! simplices,\n  where each simplex has d+1 vertices. Each simplex (relative to the lower\n  corner of the hypercube) includes the all-zeros vertex, a vertex with a\n  single one, a vertex with two ones, ... and the all-ones vertex.\n  For example, for a three-dimensional unit hypercube the 3! = 6 simplices are:\n\n  [0,0,0], [0,0,1], [0,1,1], [1,1,1]\n  [0,0,0], [0,0,1], [1,0,1], [1,1,1]\n  [0,0,0], [0,1,0], [0,1,1], [1,1,1]\n  [0,0,0], [0,1,0], [1,1,0], [1,1,1]\n  [0,0,0], [1,0,0], [1,1,0], [1,1,1]\n  [0,0,0], [1,0,0], [1,0,1], [1,1,1]\n\n  A point x in the hypercube is contained in the simplex corresponding to the\n  order of x's components. For example, x = [0.4,0.2,0.8] is contained in the\n  simplex specified by [2,0,1] (second in the above list). The weight associated\n  with each vertex in the simplex is the difference between the decreasingly\n  sorted cooredinates of the input. For details, see e.g. \"Dissection of the\n  hypercube into simplices\", D.G. Mead, Proceedings of the AMS, 76:2, Sep. 1979.\n\n  Args:\n    inputs: Tensor of shape: `(batch_size, ..., len(lattice_sizes))` or list of\n      `len(lattice_sizes)` tensors of same shape `(batch_size, ..., 1)` which\n      represents points to apply lattice interpolation to. A typical shape is\n      `(batch_size, len(lattice_sizes))`.\n    kernel: Lattice kernel of shape (num_params_per_lattice, units).\n    units: Output dimension of the lattice.\n    lattice_sizes: List or tuple of integers which represents lattice sizes of\n      layer for which interpolation is being computed.\n    clip_inputs: Whether inputs should be clipped to the input range of the\n      lattice.\n\n  Returns:\n    Tensor of shape: `(batch_size, ..., units)`.\n  \"\"\"\n  if isinstance(inputs, list):\n    inputs = tf.concat(inputs, axis=-1)\n\n  if clip_inputs:\n    inputs = _clip_onto_lattice_range(\n        inputs=inputs, lattice_sizes=lattice_sizes)\n\n  lattice_rank = len(lattice_sizes)\n  input_dim = len(inputs.shape)\n  all_size_2 = all(size == 2 for size in lattice_sizes)\n\n  # Strides are the changes in the global index (index into the flattened\n  # parameters) when moving across each dimension.\n  # E.g. 
for 2x2x2, strides are [4, 2, 1].\n  strides = tf.constant(\n      np.cumprod([1] + lattice_sizes[::-1][:-1])[::-1], tf.int32)\n\n  if not all_size_2:\n    # Find offset (into flattened parameters) for the lower corner of the\n    # hypercube in which the input lands.\n    lower_corner_coordinates = tf.cast(inputs, tf.int32)\n    # Avoid the corner case of landing on the outermost edge.\n    lower_corner_coordinates = tf.minimum(lower_corner_coordinates,\n                                          np.array(lattice_sizes) - 2)\n\n    # Multiplying coordinates by strides and summing up gives the index into\n    # the flattened parameter tensor.\n    # Note: Alternative method using tf.tensordot + tf.expand_dims is slower.\n    lower_corner_offset = tf.reduce_sum(\n        lower_corner_coordinates * strides, axis=-1, keepdims=True)\n\n    # Continue simplex interpolation with the residuals.\n    inputs = inputs - tf.cast(lower_corner_coordinates, inputs.dtype)\n\n  # Get sorted values and indices.\n  # TODO: investigate if there is a way to avoid sorting twice.\n  sorted_indices = tf.argsort(inputs, direction=\"DESCENDING\")\n  sorted_inputs = tf.sort(inputs, direction=\"DESCENDING\")\n\n  # Simplex interpolation weights are the deltas between residuals.\n  no_padding_dims = [[0, 0]] * (input_dim - 1)\n  sorted_inputs_padded_left = tf.pad(\n      sorted_inputs, no_padding_dims + [[1, 0]], constant_values=1.)\n  sorted_inputs_padded_right = tf.pad(\n      sorted_inputs, no_padding_dims + [[0, 1]], constant_values=0.)\n  weights = sorted_inputs_padded_left - sorted_inputs_padded_right\n\n  # Calculate cumsum over the strides of sorted dimensions to get index of\n  # simplex vertices into the flattened lattice parameters.\n  sorted_strides = tf.gather(strides, sorted_indices)\n  if all_size_2:\n    # Lower corner offset is 0 for 2^d lattices.\n    corner_offset_and_sorted_strides = tf.pad(sorted_strides,\n                                              no_padding_dims + [[1, 0]])\n  else:\n    corner_offset_and_sorted_strides = tf.concat(\n        [lower_corner_offset, sorted_strides], axis=-1)\n  indices = tf.cumsum(corner_offset_and_sorted_strides, axis=-1)\n\n  # Get parameter values of simplex indices.\n  if units == 1:\n    gathered_params = tf.gather(tf.reshape(kernel, [-1]), indices)\n  else:\n    # We now have two tensors 'indices' and 'weights' of shape (batch, units).\n    # The kernel is of shape (num_params_per_lattice, units).\n    # In order to use tf.gather, we need to convert 'indices' so that they are\n    # indices into the flattened parameter tensor.\n    # Note: Alternative method that uses a transpose on the parameters instead\n    # of a multiply on the indices is slower with typical batch sizes.\n    unit_offset = tf.constant([[i] * (lattice_rank + 1) for i in range(units)])\n    flat_indices = indices * units + unit_offset\n    gathered_params = tf.gather(tf.reshape(kernel, [-1]), flat_indices)\n\n  # Dot product with interpolation weights.\n  # Note: Alternative method using tf.einsum is slightly slower on CPU.\n  return tf.reduce_sum(\n      tf.multiply(gathered_params, weights), axis=-1, keepdims=(units == 1))\n\n\ndef evaluate_with_hypercube_interpolation(inputs, kernel, units, lattice_sizes,\n                                          clip_inputs):\n  \"\"\"Evaluates a lattice using hypercube interpolation.\n\n  The lattice function is multi-linearly interpolated between the 2^d vertices\n  of a hypercube. 
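For intuition, a minimal sketch of the 2 x 2 case (plain bilinear\n  interpolation; the flattened weight ordering shown is illustrative): given\n  weights `w[0..3]` and an input `(x, y)` in the unit square, the interpolated\n  value is\n\n  ```\n  (1-x)*(1-y)*w[0] + (1-x)*y*w[1] + x*(1-y)*w[2] + x*y*w[3]\n  ```\n\n  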
This interpolation method is typically slower than simplex\n  interpolation, since each value is interpolated from 2^d hypercube corners,\n  rather than d+1 simplex corners. For details, see e.g. \"Dissection of the\n  hypercube into simplices\", D.G. Mead, Proceedings of the AMS, 76:2, Sep. 1979.\n\n  Args:\n    inputs: Tensor representing points to apply lattice interpolation to. If\n      units = 1, tensor should be of shape: `(batch_size, ...,\n        len(lattice_sizes))` or list of `len(lattice_sizes)` tensors of same\n        shape `(batch_size, ..., 1)`.\n      If units > 1, tensor should be of shape: `(batch_size, ..., units,\n        len(lattice_sizes))` or list of `len(lattice_sizes)` tensors of same\n        shape `(batch_size, ..., units, 1)`. A typical shape is `(batch_size,\n        len(lattice_sizes))`.\n    kernel: Lattice kernel of shape (num_params_per_lattice, units).\n    units: Output dimension of the lattice.\n    lattice_sizes: List or tuple of integers which represents lattice sizes of\n      layer for which interpolation is being computed.\n    clip_inputs: Whether inputs should be clipped to the input range of the\n      lattice.\n\n  Returns:\n    Tensor of shape: `(batch_size, ..., units)`.\n  \"\"\"\n  interpolation_weights = compute_interpolation_weights(\n      inputs=inputs, lattice_sizes=lattice_sizes, clip_inputs=clip_inputs)\n\n  if units == 1:\n    # Weights shape: (batch-size, ..., prod(lattice_sizes))\n    # Kernel shape:  (prod(lattice_sizes), 1)\n    return tf.matmul(interpolation_weights, kernel)\n  else:\n    # Weights shape: (batch-size, ..., units, prod(lattice_sizes))\n    # Kernel shape:  (prod(lattice_sizes), units)\n    return tf.reduce_sum(interpolation_weights * tf.transpose(kernel), axis=-1)\n\n\n# TODO: Rename and update usage.\ndef compute_interpolation_weights(inputs, lattice_sizes, clip_inputs=True):\n  \"\"\"Computes weights for hypercube lattice interpolation.\n\n  Running time: `O(batch_size * prod(lattice_sizes))`\n\n  If `clip_inputs == True`, inputs outside of the range defined by\n  `lattice_sizes` will be clipped into the lattice input range. If not, the\n  corresponding weights will linearly approach 0.0 with input moving away from\n  the valid input range.\n\n  Args:\n    inputs: Tensor of shape: `(batch_size, ..., len(lattice_sizes))` or list of\n      `len(lattice_sizes)` tensors of same shape `(batch_size, ..., 1)` which\n      represents points to apply lattice interpolation to. 
A typical shape is\n      `(batch_size, len(lattice_sizes))`.\n    lattice_sizes: List or tuple of integers which represents lattice sizes of\n      layer for which interpolation is being computed.\n    clip_inputs: Whether inputs should be clipped to the input range of the\n      lattice.\n\n  Raises:\n    ValueError: If last dimension of `inputs` does not match `lattice_sizes`.\n\n  Returns:\n    Interpolation weights tensor of shape:\n    `(batch_size, ..., prod(lattice_sizes))`.\n  \"\"\"\n  if isinstance(inputs, list):\n    input_shape = [tensor.shape for tensor in inputs]\n    input_dtype = inputs[0].dtype\n  else:\n    input_shape = inputs.shape\n    input_dtype = inputs.dtype\n  verify_hyperparameters(lattice_sizes=lattice_sizes, input_shape=input_shape)\n\n  # Special case: 2^d lattice with input passed in as a single tensor\n  if all(size == 2 for size in lattice_sizes) and not isinstance(inputs, list):\n    w = tf.stack([(1.0 - inputs), inputs], axis=-1)\n    if clip_inputs:\n      w = tf.clip_by_value(w, clip_value_min=0, clip_value_max=1)\n    one_d_interpolation_weights = tf.unstack(w, axis=-2)\n    return batch_outer_operation(one_d_interpolation_weights, operation=\"auto\")\n\n  if clip_inputs:\n    inputs = _clip_onto_lattice_range(\n        inputs=inputs, lattice_sizes=lattice_sizes)\n\n  # Create interpolation keypoints in advance in order to reuse them for all\n  # dimensions of same size.\n  dim_keypoints = {}\n  for dim_size in set(lattice_sizes):\n    dim_keypoints[dim_size] = tf.constant([i for i in range(dim_size)],\n                                          dtype=input_dtype)\n\n  # Bucketize in order to share interpolation ops across consecutive dims of\n  # same size.\n  bucketized_inputs = _bucketize_consequtive_equal_dims(\n      inputs=inputs, lattice_sizes=lattice_sizes)\n\n  one_d_interpolation_weights = []\n  for tensor, bucket_size, dim_size in bucketized_inputs:\n    if bucket_size > 1:\n      # Within a bucket all dims have the same lattice size so instead of\n      # splitting before interpolation we split after interpolation.\n      # Expand dims in order to make interpolation through broadcasting work.\n      tensor = tf.expand_dims(tensor, axis=-1)\n\n    # Broadcasting subtraction op.\n    distance = tf.abs(tensor - dim_keypoints[dim_size])\n    # The following ops do the following:\n    # 1) if distance >= 1.0 then set interpolation weight to 0.0.\n    # 2) if distance < 1.0 then set interpolation weight to 1.0 - distance.\n    weights = 1.0 - tf.minimum(distance, 1.0)\n\n    if bucket_size == 1:\n      one_d_interpolation_weights.append(weights)\n    else:\n      one_d_interpolation_weights.extend(tf.unstack(weights, axis=-2))\n\n  return batch_outer_operation(one_d_interpolation_weights, operation=\"auto\")\n\n\ndef batch_outer_operation(list_of_tensors, operation=\"auto\"):\n  \"\"\"Computes outer operation of last dimensions of each of given tensors.\n\n  Args:\n    list_of_tensors: List of tensors of same shape `(batch_size, ..., k[i])`\n      where everything except `k[i]` matches.\n    operation: - binary TF operation which supports broadcasting to be applied.\n      - string \"auto\" in order to apply tf.multiply for the first several\n      tensors and tf.matmul for the remaining ones.\n\n  Returns:\n    Tensor of shape: `(batch_size, ..., mul_i(k[i]))`.\n  \"\"\"\n  # Alternative implementation using tf.einsum creates fewer graph nodes.\n  # This is slightly slower on CPU as of 2020/5, but the timing results might\n  # change with different 
setup/platform/hardware.\n  # Create a formula for outer product. e.g. '...a,...b,...c->...abc'\n  # if operation == \"auto\":\n  #   n = len(list_of_tensors)\n  #   chars = string.ascii_lowercase[:n]\n  #   eqn = \",\".join([\"...\" + c for c in chars]) + \"->...\" + \"\".join(chars)\n  #   result = tf.einsum(eqn, *list_of_tensors)\n  #   result_shape = [-1] + [int(size) for size in result.shape[1:]]\n  #   output_shape = result_shape[:-n] + [np.prod(result_shape[-n:])]\n  #   return tf.reshape(result, shape=output_shape)\n\n  if len(list_of_tensors) == 1:\n    return list_of_tensors[0]\n\n  # Dimensions of size '1' at position -1 of first tensor and -2 of second\n  # tensor will result in outer operation due to broadcasting.\n  result = tf.expand_dims(list_of_tensors[0], axis=-1)\n\n  for i, tensor in enumerate(list_of_tensors[1:]):\n    if operation == \"auto\":\n      # Threshold 6 determined empirically for 2^d lattices.\n      op = tf.multiply if i < 6 else tf.matmul\n    else:\n      op = operation\n\n    result = op(result, tf.expand_dims(tensor, axis=-2))\n\n    # For TF1 compatibility convert shape to integers allowing first dimension\n    # to be undefined.\n    #\n    # If we want to support arbitrary number of undefined dimensions we must\n    # compute new_shape using tf ops. It is undesirable because we want to\n    # minimize graph size.\n    shape = [-1] + [int(size) for size in result.shape[1:]]\n\n    # Merge last 2 dimensions which we just multiplied.\n    new_shape = shape[:-2] + [shape[-2] * shape[-1]]\n\n    # Since we are doing reshape anyway append 1 to prepare 'result' for\n    # following outer operation.\n    if i < len(list_of_tensors) - 2:\n      new_shape.append(1)\n\n    result = tf.reshape(result, shape=new_shape)\n  return result\n\n\ndef _clip_onto_lattice_range(inputs, lattice_sizes):\n  \"\"\"Clips inputs onto valid input range for given lattice_sizes.\n\n  Args:\n    inputs: `inputs` argument of `compute_interpolation_weights`.\n    lattice_sizes: list or tuple of integers which represents lattice sizes to\n      clip onto.\n\n  Returns:\n    Clipped `inputs`.\n  \"\"\"\n  if not isinstance(inputs, list):\n    upper_bounds = [dim_size - 1.0 for dim_size in lattice_sizes]\n    return tf.clip_by_value(\n        inputs,\n        clip_value_min=tf.zeros(shape=len(lattice_sizes), dtype=inputs.dtype),\n        clip_value_max=tf.constant(upper_bounds, dtype=inputs.dtype))\n  else:\n    # Share bound constant across dimensions of same size.\n    dim_upper_bounds = {}\n    for dim_size in set(lattice_sizes):\n      dim_upper_bounds[dim_size] = tf.constant(\n          dim_size - 1.0, dtype=inputs[0].dtype)\n    dim_lower_bound = tf.zeros(shape=[], dtype=inputs[0].dtype)\n\n    clipped_inputs = []\n    for one_d_input, dim_size in zip(inputs, lattice_sizes):\n      clipped_inputs.append(\n          tf.clip_by_value(\n              one_d_input,\n              clip_value_min=dim_lower_bound,\n              clip_value_max=dim_upper_bounds[dim_size]))\n    return clipped_inputs\n\n\ndef _bucketize_consequtive_equal_dims(inputs, lattice_sizes):\n  \"\"\"Groups consecutive dimensions of the same size together.\n\n  For example `lattice_sizes == [2, 2, 2, 5, 5, 2]` produces 3 buckets:\n  - bucket of size 3 which corresponds to first group of dimensions of size 2.\n  - bucket of size 2 which corresponds to group of dimensions of size 5.\n  - bucket of size 1 which corresponds to last dimension of size 2.\n  If `inputs` is a single tensor then it will be split according to 
buckets.\n\n  If `inputs` is a list of tensors then all buckets will be of size 1\n  regardless of lattice sizes in order to avoid merging tensors. In this case\n  the function acts merely as a convenience helper to unify output format.\n\n  Args:\n    inputs: `inputs` argument of `compute_interpolation_weights`.\n    lattice_sizes: list or tuple of integers which represents lattice sizes.\n\n  Returns:\n    Iterable of tuples: `(tensor, bucket_size, bucket_dim_size)` where\n    `tensor.shape[-1] == bucket_size` and `bucket_dim_size` is a lattice size\n    which corresponds to the bucket.\n  \"\"\"\n  if not isinstance(inputs, list):\n    bucket_sizes = []\n    bucket_dim_sizes = []\n    current_size = 1\n    for i in range(1, len(lattice_sizes)):\n      if lattice_sizes[i] != lattice_sizes[i - 1]:\n        bucket_sizes.append(current_size)\n        bucket_dim_sizes.append(lattice_sizes[i - 1])\n        current_size = 1\n      else:\n        current_size += 1\n    bucket_sizes.append(current_size)\n    bucket_dim_sizes.append(lattice_sizes[-1])\n    inputs = tf.split(inputs, num_or_size_splits=bucket_sizes, axis=-1)\n  else:\n    # TODO: run benchmark and figure out whether it makes sense to merge\n    # individual tensors here.\n    bucket_sizes = [1] * len(lattice_sizes)\n    bucket_dim_sizes = lattice_sizes\n  return zip(inputs, bucket_sizes, bucket_dim_sizes)\n\n\ndef default_init_params(output_min, output_max):\n  \"\"\"Returns reasonable default parameters if not defined explicitly.\n\n  Args:\n    output_min: None or minimum layer output.\n    output_max: None or maximum layer output.\n  \"\"\"\n  if output_min is not None:\n    init_min = output_min\n  elif output_max is not None:\n    init_min = min(0.0, output_max)\n  else:\n    init_min = 0.0\n\n  if output_max is not None:\n    init_max = output_max\n  elif output_min is not None:\n    init_max = max(1.0, output_min)\n  else:\n    init_max = 1.0\n\n  # Return our min and max.\n  return init_min, init_max\n\n\ndef linear_initializer(lattice_sizes,\n                       output_min,\n                       output_max,\n                       monotonicities=None,\n                       unimodalities=None,\n                       units=1,\n                       dtype=tf.float32):\n  \"\"\"Returns a lattice layer weight tensor that represents a linear function.\n\n  - The linear function will have positive coefficients for monotonic dimensions\n    and 0 otherwise. If all dimensions are unconstrained, all coefficients will\n    be positive.\n  - Linear coefficients are set such that the minimum/maximum output of the\n    lattice matches the given output_min/output_max.\n  - Each monotonic dimension contributes with same weight regardless of number\n    of vertices per dimension.\n  - No dimension can be both monotonic and unimodal.\n  - Unimodal dimensions contribute with same weight as monotonic dimensions.\n  - Unimodal dimensions linearly decrease for first `(dim_size + 1) // 2`\n    vertices and then linearly increase for following vertices.\n\n  Args:\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n    output_min: Minimum output of lattice layer after initialization.\n    output_max: Maximum output of lattice layer after initialization.\n    monotonicities: None or list or tuple of same length as lattice_sizes of {0,\n      1} which represents monotonicity constraints per dimension. 
1 stands for\n      increasing (non-decreasing in fact), 0 for no monotonicity constraints.\n    unimodalities: None or list or tuple of same length as lattice_sizes of {-1,\n      0, 1} which represents unimodality constraints per dimension. 1 indicates\n      that function first decreases then increases, -1 indicates that function\n      first increases then decreases, 0 indicates no unimodality constraints.\n    units: Output dimension of the layer. Each of units lattices will be\n      initialized identically.\n    dtype: dtype.\n\n  Returns:\n    Lattice weights tensor of shape: `(prod(lattice_sizes), units)`.\n  \"\"\"\n  verify_hyperparameters(\n      lattice_sizes=lattice_sizes,\n      monotonicities=monotonicities,\n      unimodalities=unimodalities)\n  if monotonicities is None:\n    monotonicities = [0] * len(lattice_sizes)\n  if unimodalities is None:\n    unimodalities = [0] * len(lattice_sizes)\n\n  num_constraint_dims = utils.count_non_zeros(monotonicities, unimodalities)\n  if num_constraint_dims == 0:\n    monotonicities = [1] * len(lattice_sizes)\n    num_constraint_dims = len(lattice_sizes)\n\n  dim_range = float(output_max - output_min) / num_constraint_dims\n  one_d_weights = []\n\n  for monotonicity, unimodality, dim_size in zip(monotonicities, unimodalities,\n                                                 lattice_sizes):\n    if monotonicity != 0:\n      one_d = _linspace(start=0.0, stop=dim_range, num=dim_size)\n    elif unimodality != 0:\n      decreasing = _linspace(start=dim_range, stop=0.0, num=(dim_size + 1) // 2)\n      increasing = _linspace(start=0.0, stop=dim_range, num=(dim_size + 1) // 2)\n      # For odd size dimensions we want just 1 extreme point. For even sized we\n      # want 2.\n      if unimodality == 1:\n        one_d = decreasing + increasing[dim_size % 2:]\n      else:\n        one_d = increasing + decreasing[dim_size % 2:]\n    else:\n      one_d = [0.0] * dim_size\n    # Insert batch dim of size 1 at the beginning for batch_outer_operation.\n    one_d_weights.append(tf.constant(one_d, dtype=dtype, shape=[1, dim_size]))\n\n  # Use same implementation of outer operation as interpolation logic in order\n  # to guarantee same weights order.\n  weights = batch_outer_operation(one_d_weights, operation=tf.add)\n  weights = tf.reshape(weights + output_min, shape=[-1, 1])\n  if units > 1:\n    weights = tf.tile(weights, multiples=[1, units])\n  return weights\n\n\ndef _linspace(start, stop, num):\n  \"\"\"Returns `num` uniformly spaced floats between `start` and `stop`.\"\"\"\n  if num == 1:\n    return [start]\n  return [start + (stop - start) * i / (num - 1.0) for i in range(num)]\n\n\ndef random_monotonic_initializer(lattice_sizes,\n                                 output_min,\n                                 output_max,\n                                 units=1,\n                                 dtype=tf.float32):\n  \"\"\"Returns a uniformly random sampled monotonic lattice layer weight tensor.\n\n  - The uniform random monotonic function will initilaize the lattice parameters\n    uniformly at random and make it such that the parameters are monotonically\n    increasing for each input.\n  - The random parameters will be sampled from `[output_min, output_max]`\n\n  Args:\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n    output_min: Minimum output of lattice layer after initialization.\n    output_max: Maximum output of lattice layer after initialization.\n    units: Output dimension of the layer. 
Each of units lattices will be\n      initialized identically.\n    dtype: dtype.\n\n  Returns:\n    Lattice weights tensor of shape: `(prod(lattice_sizes), units)`.\n  \"\"\"\n  # First we verify parameters\n  verify_hyperparameters(lattice_sizes=lattice_sizes)\n\n  dimension = len(lattice_sizes)\n  # Pre-compute the bases of the global index for each dimension.\n  index_bases = [1] * dimension\n  for i in range(0, dimension - 1)[::-1]:\n    index_bases[i] = index_bases[i + 1] * lattice_sizes[i + 1]\n  total_lattice_size = np.prod(lattice_sizes)\n  # Create parameter indices to later gather parameter values in the proper\n  # ordering.\n  lattice_parameter_indices = [0] * total_lattice_size\n\n  # Starting from the all-0 vertex, expand new vertices by getting the vertices\n  # that are children of the vertices expanded in the last iteration in terms of\n  # monotonic dependencies. Create constant tensor representing order of init\n  # mapping each index to its corresponding random parameter value.\n  parameter_index = 1\n  # Vertices expanded in the last iteration.\n  last_vertices = [0]\n  while last_vertices:\n    new_vertices_set = set()\n    for index in last_vertices:\n      remaining_index = index\n      # For each dimension, if the vertex is not at the end of that dimension,\n      # we can create a child of the current vertex by increasing the value\n      # of the vertex in that dimension by one.\n      for i in range(dimension):\n        index_base = index_bases[i]\n        # The value of the vertex index in the i'th dimension\n        index_dim = remaining_index // index_base\n        if index_dim < lattice_sizes[i] - 1:\n          new_index = index + index_base\n          if new_index not in new_vertices_set:\n            new_vertices_set.add(new_index)\n        remaining_index = remaining_index % index_base\n    # Randomly sort the vertices expanded in the current iteration. Note that\n    # there can be no monotonic dependency between vertices expanded in the same\n    # iteration because their sum of all dimensions are the same (we increase\n    # them one-by-one in each iteration).\n    new_vertices = list(new_vertices_set)\n    np.random.shuffle(new_vertices)\n    # Assign parameter values\n    for vertex in new_vertices:\n      lattice_parameter_indices[vertex] = parameter_index\n      parameter_index += 1\n    last_vertices = new_vertices\n\n  # Convert lattice_parameter_indices into a tensor.\n  lattice_parameter_indices = tf.constant(lattice_parameter_indices)\n  # Uniformly generate the random parameter values.\n  parameter_values = tf.random.uniform(\n      shape=[total_lattice_size],\n      minval=output_min,\n      maxval=output_max,\n      dtype=dtype)\n  parameter_values = tf.sort(parameter_values)\n  # Convert lattice_parameter_indices to weights tensor and tile if necessary.\n  weights = tf.gather(parameter_values, lattice_parameter_indices)\n  weights = tf.reshape(weights, shape=[-1, 1])\n  if units > 1:\n    weights = tf.tile(weights, multiples=[1, units])\n  return weights\n\n\n# TODO: Add final projection for unimodality constraints.\ndef _approximately_project_monotonicity(weights, lattice_sizes, monotonicities):\n  \"\"\"Approximately projects to strictly meet monotonicity constraints.\n\n  Algorithm details:\n\n  Definition:\n  A[i] refer to i-th coordinate of vertex A.\n  For 2 vertices A and B:\n    \"A <p B\": if A[i] <= B[i] for all monotonic dimensions i. 
(i.e. A is Pareto-dominated\n      by B)\n\n  In order for the lattice to be monotonic it is sufficient that either:\n    1) for any vertex V: weight[V] >= weight[X] for any vertex X that: X <p V.\n  or\n    2) for any vertex V: weight[V] <= weight[X] for any vertex X that: V <p X.\n\n  For example, consider the lattice:\n\n  ```\n  0---1---2---3\n  |   |   |   |\n  4---5---6---7\n  |   |   |   |\n  8---9---10--11\n  ```\n\n  For instance, for vertex 6 it's sufficient that:\n\n  weight[6] >= max(weight[4, 5, 8, 9, 10])\n  Or:\n  weight[6] <= min(weight[2, 3, 7])\n\n  Given the above definition, we can use either of the following update rules to\n  approximately project into the feasible space:\n  max_proj[V] = max(weight[X]) for any X that: X <p V.\n  min_proj[V] = min(weight[X]) for any X that: V <p X.\n\n  It's clear though that these algorithms either only increase weights or only\n  decrease weights. We know that a true projection algorithm increases some\n  weights and decreases others. To get closer to a true projection, we modify\n  and use both update rules as follows:\n\n  1) half_proj[V] = weight[V] + (max_proj[V] - weight[V]) / 2\n     ... move halfway up towards max_proj.\n  2) min_max_proj[V] = min_proj[half_proj[V]]\n     ... move the remaining way down towards min_proj.\n\n  Differs from _project_partial_monotonicity in that this algorithm guarantees a\n  global satisfying solution for all monotonicity constraints.\n\n  Args:\n    weights: Tensor with weights whose shape matches lattice_sizes.\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n      which correspond to weights.\n    monotonicities: List or tuple of same length as lattice_sizes of {0, 1}\n      which represents monotonicity constraints per dimension. 1 stands for\n      increasing (non-decreasing in fact), 0 for no monotonicity constraints.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n\n  # To compute max_proj[V] for all V altogether compute cumulative maximum\n  # along every monotonic dimension in arbitrary order.\n  max_projection = weights\n  for dim in range(len(lattice_sizes)):\n    if monotonicities[dim] == 0:\n      continue\n    layers = tf.unstack(max_projection, axis=dim)\n    for i in range(1, len(layers)):\n      # Computing cumulative maximum.\n      layers[i] = tf.maximum(layers[i], layers[i - 1])\n    max_projection = tf.stack(layers, axis=dim)\n\n  half_projection = (weights + max_projection) / 2.0\n\n  min_projection = half_projection\n  for dim in range(len(lattice_sizes)):\n    if monotonicities[dim] == 0:\n      continue\n    layers = tf.unstack(min_projection, axis=dim)\n    for i in range(len(layers) - 2, -1, -1):\n      # Compute cumulative minimum in reverse order compared to the cumulative\n      # maximum above.\n      layers[i] = tf.minimum(layers[i], layers[i + 1])\n    min_projection = tf.stack(layers, axis=dim)\n\n  return min_projection\n\n\ndef _approximately_project_edgeworth(weights, lattice_sizes, units,\n                                     edgeworth_trusts):\n  \"\"\"Approximately projects to strictly meet all edgeworth trust constraints.\n\n  Note that this function will not introduce violations to any\n  previously-satisfied monotonicity constraints.\n\n  Algorithm details:\n\n  For a constraint on main dimension i and conditional dimension j, consider\n  some slice of weights that is fixed along all other dimensions, leaving a grid\n\n  ```\n  0---1---2---3\n  |   |   |   |\n  4---5---6---7\n  |   |   |   |\n  
8---9---10--11\n  ```\n\n  You can think of all the other dimensions as other such grids stacked behind\n  this one, e.g. weight[8] and the points behind it are all such points with\n  index 0 in the i'th and j'th dimensions, and weight[6] and the points behind\n  it are all such points with index 2 in the i'th dimension and index 1 in the\n  j'th.\n\n  To enforce this edgeworth trust constraint without messing up monotonicity or\n  other trust constraints, the key idea is that we will always translate all\n  points 'behind' a point on this grid together. This ensures that no other\n  trust constraints will be violated, since all other weight differences\n  constrained by trust constraints will occur 'behind' a single such point\n  (no conditional feature can also be a main feature).\n\n  With that in mind, we project to edgeworth trust on this grid while\n  maintaining monotonicity by working up and right and always increasing the\n  top-right point in each four-point square. Here, we would first find how much\n  we need to increase weight[5] by to maintain edgeworth trust on {4,5,8,9}. To\n  follow the principle above, we then consider all such squares 'behind'\n  {4,5,8,9} and find the biggest such difference. weight[5] and all points\n  behind will be increased by that amount, and then we continue until fixing the\n  top-right grid, {2,3,6,7}.\n\n  If the trust constraint is in the opposite direction, i.e. cond_direction =\n  -1, repeat all of the above except that we start in the top-right {2,3,6,7}\n  grid and always lower the bottom-left point (weight[6] to start) until we\n  reach the bottom-left {4,5,8,9} grid.\n\n  Differs from _project_partial_edgeworth in that this algorithm guarantees a\n  global satisfying solution for all edgeworth trust constraints.\n\n  Args:\n    weights: Tensor with weights whose shape matches lattice_sizes\n      plus units if units > 1.\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n      which correspond to weights.\n    units: Output dimension of the lattice.\n    edgeworth_trusts: None or iterable of three-element tuples. First element is\n      the index of the main (monotonic) feature. Second element is the index of\n      the conditional feature. Third element is the direction of trust: 1 if\n        higher values of the conditional feature should increase trust in the\n        main feature and -1 otherwise.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n\n  # Project onto trust constraints by cumulatively fixing violations.\n  trust_projection = weights\n  for main_dim, cond_dim, cond_direction in edgeworth_trusts or []:\n    layers = _unstack_nd(trust_projection, [main_dim, cond_dim])\n    # Unlike other trust projections, cannot simply reverse layers beforehand\n    # based on cond_direction; asymmetry would break algorithm.\n    dims = len(layers[0][0].shape)\n    axis = (tf.constant(list(range(dims - 1)), dtype=tf.int32) if units > 1\n            else None)\n    if cond_direction > 0:\n      for i in range(0, lattice_sizes[main_dim] - 1):\n        for j in range(0, lattice_sizes[cond_dim] - 1):\n          difference_in_slopes = ((layers[i + 1][j] - layers[i][j]) -\n                                  (layers[i + 1][j + 1] - layers[i][j + 1]))\n          # Move all weights by the value of the biggest violation to both\n          # satisfy this constraint and not hurt others. 
See function comments\n          # for more details.\n          max_violation = tf.maximum(\n              tf.reduce_max(difference_in_slopes, axis=axis), 0)\n          layers[i + 1][j + 1] += max_violation\n    else:\n      for i in range(lattice_sizes[main_dim] - 2, -1, -1):\n        for j in range(lattice_sizes[cond_dim] - 2, -1, -1):\n          difference_in_slopes = ((layers[i + 1][j + 1] - layers[i][j + 1]) -\n                                  (layers[i + 1][j] - layers[i][j]))\n          max_violation = tf.maximum(\n              tf.reduce_max(difference_in_slopes, axis=axis), 0)\n          layers[i][j] -= max_violation\n    trust_projection = _stack_nd(layers, [main_dim, cond_dim])\n\n  return trust_projection\n\n\n# TODO: It is likely that this algorithm will work for all trapezoid\n# trust constraints without needing the reduce_max, as long as there are no\n# edgeworth constraints. If true, consider using that approach when possible.\ndef _approximately_project_trapezoid(weights, lattice_sizes, units,\n                                     trapezoid_trusts, edgeworth_trusts):\n  \"\"\"Approximately projects to strictly meet all trapezoid trust constraints.\n\n  Note that this function will not introduce violations to any\n  previously-satisfied monotonicity or edgeworth constraints.\n\n  Algorithm details:\n\n  For a constraint on main dimension i and conditional dimension j, consider\n  some slice of weights that is fixed along all other dimensions, leaving a grid\n\n  ```\n  0---1---2---3\n  |   |   |   |\n  4---5---6---7\n  |   |   |   |\n  8---9---10--11\n  ```\n\n  You can think of all the other dimensions as other such grids stacked behind\n  this one, e.g. weight[8] and the points behind it are all such points with\n  index 0 in the i'th and j'th dimensions, and weight[6] and the points behind\n  it are all such points with index 2 in the i'th dimension and index 1 in the\n  j'th.\n\n  We project to trapezoid trust on this grid by working up both edges of\n  the lattice and only ever decreasing weights on the low main_feature side and\n  increasing weights on the high main_feature side. In the above example, we\n  would first consider the pair {8, 4} and update weight 4 to be min(8, 4),\n  before then looking at {4, 0} and updating 0 to be min(4, 0). Similarly set\n  weight 7 to be max(7, 11) and then weight 3 to max(3, 7). Flip the orders if\n  cond_direction is -1: work down instead of up.\n\n  Unlike in the edgeworth trust case, we do not necessarily look 'behind' the\n  page and update all points behind a given grid point by the maximum violation\n  at each step. It turns out that while this does have the nice property of\n  maintaining almost all types of edgeworth constraints, for the same reason\n  that the edgeworth algorithm does (co-movement of weights involved in other\n  constraints), it can actually break other trapezoid constraints, namely those\n  which share the same conditional feature.\n\n  There is one exception, which is the matching edgeworth trust constraint. In\n  this case, the trapezoid updates only touch one corner of each edgeworth\n  constraint and so can violate them. The solution is to update by the max of\n  all violations behind the page and all violations encountered below in the\n  grid.\n\n  If you separately update each grid by the violations in that grid, this update\n  procedure turns out to respect all trapezoid constraints. The rationale is a\n  bit more subtle than in the edgeworth case. 
The basic idea is that since each\n  trapezoid and monotonicity constraint operates on two weights that are next to\n  each other (i.e. differ only in the index of one dimension), we can create\n  a 'square' of points in which one edge goes across the constraint we want to\n  maintain and the perpendicular edges go across the constraint we are updating.\n\n  For example, consider the 4 weights\n\n  ```\n  A -- B\n  |    |\n  C -- D\n  ```\n\n  A/B and C/D differ in the same one index (the constraint we hope to maintain)\n  while A/C and B/D differ across the conditional index of the trapezoid\n  constraint we are updating. Say we are focused on whether we maintain A'<=B'\n  (A' is A after imposing trapezoid trust) and we are operating on the 'min main\n  feature' side of the lattice so that any updates that occur will lower\n  weights. If B'=B after trapezoid trust, things are easy because A'<=A by 'min\n  main feature' and A<=B by the preexisting constraint. If not, and B'<B, we\n  start with A'<=C' by trapezoid trust and C'<=C by 'min main feature'. By\n  the preexisting constraints, C<=D, and by the trapezoid trust update procedure\n  and the fact that B has changed, it must be that B'=D.\n\n  Unfortunately, this algorithm will break edgeworth constraints.\n\n  The solution we take is to update independently for each grid whenever we have\n  only trapezoid constraints and to update with the max across all other\n  dimensions (and potentially below, in the case of matching constraints)\n  when there are both types of constraints, recognizing that in this second case\n  we may not achieve guarantees for trapezoid constraints which share a\n  conditional feature.\n\n  Differs from _project_partial_trapezoid in that this algorithm guarantees a\n  global satisfying solution for all trapezoid trust constraints.\n\n  Args:\n    weights: Tensor with weights whose shape matches lattice_sizes plus units\n      if units > 1.\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n      which correspond to weights.\n    units: Output dimension of the lattice.\n    trapezoid_trusts: None or iterable of three-element tuples. First element is\n      the index of the main (monotonic) feature. Second element is the index of\n      the conditional feature. Third element is the direction of trust set to 1\n      if higher values of the conditional feature should increase trust in the\n      main feature and -1 otherwise.\n    edgeworth_trusts: None or iterable of three-element tuples. First element is\n      the index of the main (monotonic) feature. Second element is the index of\n      the conditional feature. 
Third element is the direction of trust set to 1\n      if higher values of the conditional feature should increase trust in the\n      main feature and -1 otherwise.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n\n  any_edgeworth = bool(edgeworth_trusts)\n\n  # Project onto trust constraints by cumulatively fixing violations.\n  for main_dim, cond_dim, cond_direction in trapezoid_trusts or []:\n    layers = _unstack_nd(weights, [main_dim, cond_dim])\n    max_main_dim = lattice_sizes[main_dim] - 1\n    same_edgeworth = (main_dim, cond_dim,\n                      cond_direction) in set(edgeworth_trusts or [])\n    if cond_direction < 0:\n      layers = _reverse_second_list_dimension(layers)\n    lhs_update, rhs_update = 0, 0\n    for j in range(0, lattice_sizes[cond_dim] - 1):\n      lhs_difference = layers[0][j + 1] - layers[0][j]\n      lhs_update = _trapezoid_violation_update(lhs_difference, units,\n                                               any_edgeworth, same_edgeworth,\n                                               lhs_update)\n      layers[0][j + 1] -= lhs_update\n      rhs_difference = layers[max_main_dim][j] - layers[max_main_dim][j + 1]\n      rhs_update = _trapezoid_violation_update(rhs_difference, units,\n                                               any_edgeworth, same_edgeworth,\n                                               rhs_update)\n      layers[max_main_dim][j + 1] += rhs_update\n    if cond_direction < 0:\n      layers = _reverse_second_list_dimension(layers)\n    weights = _stack_nd(layers, [main_dim, cond_dim])\n\n  return weights\n\n\ndef _trapezoid_violation_update(differences, units, any_edgeworth,\n                                same_edgeworth, prior_update):\n  \"\"\"Calculates update amount based on violations for trapezoid projection.\n\n  Note that the shape of the returned tensor is different based on the value\n  of the any_edgeworth boolean feature. A single-valued tensor is\n  returned when it is true, representing the amount by which all relevant\n  weights will be updated. 
A tensor matching the shape of differences is\n  returned when it is false, representing the individual updates to be applied\n  to each relevant weight.\n\n  Args:\n    differences: Tensor containing amounts by which constraints are satisfied or\n      violated.\n    units: Output dimension of the lattice.\n    any_edgeworth: Boolean for whether any edgeworth trust constraints are set\n      for this lattice layer.\n    same_edgeworth: Boolean for whether there is a matching edgeworth constraint\n      for the trapezoid constraint being updated.\n    prior_update: Tensor containing previous trapezoid constraint update.\n\n  Returns:\n    Tensor either matching the shape of the input differences tensor or\n    consisting of a single element.\n\n  \"\"\"\n  dims = len(differences.shape) - 1\n  axis = tf.constant(list(range(dims)), dtype=tf.int32) if units > 1 else None\n  if any_edgeworth and same_edgeworth:\n    return tf.maximum(tf.maximum(\n        tf.reduce_max(differences, axis=axis), 0), prior_update)\n  elif any_edgeworth:\n    return tf.maximum(tf.reduce_max(differences, axis=axis), 0)\n  else:\n    return tf.maximum(differences, 0)\n\n\ndef _approximately_project_bounds(weights, units, output_min, output_max):\n  \"\"\"Approximately projects to strictly meet min/max constraints.\n\n  Note that this function will not introduce violations to any\n  previously-satisfied monotonicity or trust constraints.\n\n  Algorithm details:\n\n  The idea of the min/max projection is to evenly scale (squish) the weights\n  to fit within the desired range. This ensures that the weight differences-of-\n  differences encountered in the trust constraints will not be affected.\n\n  For example, given min_weight < output_min < 0 < output_max < max_weight, we\n  will translate all weights such that min_weight = 0, then scale the weights\n  by the difference in ratios between max_weight - min_weight and output_max -\n  output_min, and then translate back so that min_weight = output_min and\n  max_weight = output_max.\n\n  Args:\n    weights: Tensor with weights whose shape matches `lattice_sizes` plus units\n      if units > 1.\n    units: Output dimension of the lattice.\n    output_min: None or minimum possible output.\n    output_max: None or maximum possible output.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n\n  # Project into [output_min, output_max] by translating and scaling output if\n  # necessary.\n  dims = len(weights.shape) - 1\n  axis = tf.constant(list(range(dims)), dtype=tf.int32) if units > 1 else None\n  final_projection = weights\n  if output_max is None and output_min is not None:\n    final_projection += tf.maximum(\n        output_min - tf.reduce_min(final_projection, axis=axis), 0)\n  elif output_max is not None and output_min is None:\n    final_projection -= tf.maximum(\n        tf.reduce_max(final_projection, axis=axis) - output_max, 0)\n  elif output_max is not None and output_min is not None:\n    max_violation = tf.maximum(\n        tf.reduce_max(final_projection, axis=axis) - output_max, 0)\n    min_violation = tf.maximum(\n        output_min - tf.reduce_min(final_projection, axis=axis), 0)\n    final_projection += (min_violation - output_min)\n    final_projection *= ((output_max - output_min) /\n                         ((output_max + max_violation) -\n                          (output_min - min_violation)))\n    final_projection += output_min\n  return final_projection\n\n\ndef finalize_constraints(weights,\n                  
       lattice_sizes,\n                         monotonicities,\n                         edgeworth_trusts=None,\n                         trapezoid_trusts=None,\n                         output_min=None,\n                         output_max=None):\n  \"\"\"Approximately projects lattice weights to strictly satisfy all constraints.\n\n  This projection guarantees that constraints are strictly met, but it is not\n  an exact projection w.r.t. the L2 norm. The computational cost is\n  `O((num_monotonic_dims + num_trust_constraints) * num_lattice_weights)`.\n\n  See helper functions `_approximately_project_*` for details of the individual\n  projection algorithms for each set of constraints. They are designed to be\n  applied sequentially: monotonicity, then edgeworth, trapezoid, and bounds if\n  necessary. This is because the projection algorithms are guaranteed to not\n  violate *previous* constraints, though they may lead to violations of *later*\n  constraints.\n\n  Args:\n    weights: Lattice weights tensor of shape: `(prod(lattice_sizes), units)`.\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n      which correspond to weights.\n    monotonicities: List or tuple of same length as lattice_sizes of {0, 1}\n      which represents monotonicity constraints per dimension. 1 stands for\n      increasing (non-decreasing in fact), 0 for no monotonicity constraints.\n    edgeworth_trusts: None or iterable of three-element tuples. First element is\n      the index of the main (monotonic) feature. Second element is the index of\n      the conditional feature. Third element is the direction of trust set to 1\n      if higher values of the conditional feature should increase trust in the\n      main feature and -1 otherwise.\n    trapezoid_trusts: None or iterable of three-element tuples. First element is\n      the index of the main (monotonic) feature. Second element is the index of\n      the conditional feature. 
Third element is the direction of trust set to 1\n      if higher values of the conditional feature should increase trust in the\n      main feature and -1 otherwise.\n    output_min: None or minimum possible output.\n    output_max: None or maximum possible output.\n\n  Returns:\n    Projected weights tensor of same shape as `weights`.\n  \"\"\"\n  if utils.count_non_zeros(monotonicities) == 0:\n    return weights\n  units = weights.shape[1]\n  if units > 1:\n    lattice_sizes = lattice_sizes + [int(units)]\n    if monotonicities:\n      monotonicities = monotonicities + [0]\n\n  weights = tf.reshape(weights, shape=lattice_sizes)\n\n  weights = _approximately_project_monotonicity(weights, lattice_sizes,\n                                                monotonicities)\n  if edgeworth_trusts or trapezoid_trusts:\n    weights = _approximately_project_edgeworth(weights, lattice_sizes, units,\n                                               edgeworth_trusts)\n    weights = _approximately_project_trapezoid(weights, lattice_sizes, units,\n                                               trapezoid_trusts,\n                                               edgeworth_trusts)\n    # Simple capping, applied in a later step, adds less distortion than this\n    # scaling projection; however, it could violate trust constraints.\n    weights = _approximately_project_bounds(weights, units, output_min,\n                                            output_max)\n  return tf.reshape(weights, shape=[-1, units])\n\n\n# TODO: approach used to implement regluarizers is likely to be more\n# efficient than one used here. Especially on TPU. Investigate it.\ndef _project_partial_monotonicity(weights, lattice_sizes, monotonicities,\n                                  unimodalities, dimension, constraint_group):\n  \"\"\"Applies exact monotonicity projection to a subset of a single dimension.\n\n  Algorithm details:\n\n  In order to project into k constrained dimensions we split all constraints\n  into 2k sets in such way that within each sets all constraints are\n  independent. These 2k sets are chosen in such way that for each constrained\n  dimension we have 2 sets of constraints: even and odd constraints according to\n  index of smallest vertex in constraint. We apply Dykstra's algorithm to these\n  sets handling each individual constraint within each set independently.\n\n  This function in particular, then, operates on one of these independent sets,\n  as defined by a specific dimension and constraint group: 0 for the even\n  constraints and 1 for the odd constraints.\n\n  Note that in case of just 2 lattice vertices per dimension odd set for that\n  dimension will be empty.\n\n  * k constrained dimensions projection:\n  If we know how to project into single constrained dimension then we can use\n  Dykstra algorithm to project into union of all k constrained dimensions.\n\n  * Single constrained dimension projection:\n  For single dimension projection we have multiple independent 1-d sequences of\n  constrained weights of same length.\n  For example 2 x 6 lattice with monotonicity along 2-nd dimension:\n\n  ```\n  0--<--1--<--2--<--3--<--4--<--5\n  |     |     |     |     |     |\n  6--<--7--<--8--<--9--<--10-<--11\n  ```\n\n  we have 2 independent rows of constraints. It's clear that both rows can be\n  projected independently.\n\n  To project 1 row, we can again apply Dykstra's algorithm splitting all\n  constraints into two sets: constraints with odd indices and constraints with\n  even indices. 
For example for first row:\n  - even constraints set: {0 < 1, 2 < 3, 4 < 5}\n  - odd constraints set:  {1 < 2, 3 < 4}\n\n  Within each set no constraints interact with each other so we can project\n  every individual constraint independently.\n\n  * Individual constraint projection:\n  Constraint weight[0] <= weight[1]:\n  - weight[0] = min(weight[0], (weight[0] + weight[1]) / 2)\n  - weight[1] = max(weight[1], (weight[0] + weight[1]) / 2)\n\n  Differs from _approximately_project_monotonicity in that this algorithm\n  - Only operates on a single dimension.\n  - Does not guarantee an satisfying solution to the full monotonicity\n    constraint.\n  - Exactly projects (in L2 terms) on the subset of constraints it does\n    operate on.\n\n  Args:\n    weights: Tensor with weights of lattice layer, with shape lattice_sizes.\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n      which correspond to weights.\n    monotonicities: None or list or tuple of same length as lattice_sizes of {0,\n      1} which represents monotonicity constraints per dimension. 1 stands for\n      increasing (non-decreasing in fact), 0 for no monotonicity constraints.\n    unimodalities: None or list or tuple of same length as lattice_sizes of {-1,\n      0, 1} which represents unimodality constraints per dimension. 1 indicates\n      that function first decreases then increases, -1 indicates that function\n      first increases then decreases, 0 indicates no unimodality constraints.\n    dimension: Index of feature to which we are applying constraints.\n    constraint_group: 0 or 1 as defined above, representing whether we are\n      operating on 'even' or 'odd' constraints.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n\n  Raises:\n    ValueError: If provided dimension has no monotonicity or unimodality\n      constraint associated with it.\n  \"\"\"\n\n  if monotonicities[dimension] == 0 and unimodalities[dimension] == 0:\n    raise ValueError(\n        \"Trying to project monotonicity and unimodality onto unconstrained \"\n        \"dimension: %d.\" % dimension)\n\n  layers = tf.unstack(weights, axis=dimension)\n  for i in range(constraint_group, lattice_sizes[dimension] - 1, 2):\n    # Project individual independent constraints.\n    average = (layers[i] + layers[i + 1]) / 2.0\n\n    if monotonicities[dimension] == 1:\n      layers[i] = tf.minimum(layers[i], average)\n      layers[i + 1] = tf.maximum(layers[i + 1], average)\n\n    if unimodalities[dimension] != 0:\n      is_first_part = (i < lattice_sizes[dimension] // 2)\n      if ((unimodalities[dimension] == -1 and is_first_part) or\n          (unimodalities[dimension] == 1 and not is_first_part)):\n        layers[i] = tf.minimum(layers[i], average)\n        layers[i + 1] = tf.maximum(layers[i + 1], average)\n      else:\n        layers[i] = tf.maximum(layers[i], average)\n        layers[i + 1] = tf.minimum(layers[i + 1], average)\n\n  return tf.stack(layers, axis=dimension)\n\n\ndef _project_partial_edgeworth(weights, lattice_sizes, edgeworth_trust,\n                               constraint_group):\n  \"\"\"Applies exact edgeworth trust projection to a subset of one constraint.\n\n  Algorithm details:\n\n  For the Edgeworth trust projection, we follow a similar approach to the\n  monotonicity projection by splitting up the constraints into independent sets.\n  Here, each trust constraint touches every lattice vertex, but can be broken up\n  into 4 independent sets of constraints, based on 
whether the constraint's\n  smaller indices along the main and conditional dimensions are even or odd.\n  That leaves us with 4t sets of constraints if we have t trust constraints,\n  which we can sequentially project onto with the Dykstra's algorithm.\n\n  This function applies to a single set of independent constraints within a\n  single trust constraint. The constraint group can take the value (0,0), (0,1),\n  (1,0), or (1,1) corresponding to even (0) or odd (1) for the main and\n  conditional dimensions, respectively.\n\n  * k trust constraints projection:\n  If we know how to project into single trust constraint then we can use\n  Dykstra algorithm to project into union of all k trust constraints.\n\n  * Single trust constraint projection:\n  Edgeworth constraints require the difference in weights across the main\n  feature to be larger when the conditional feature is higher. We can think of\n  this as separate constraints applied to each 'square' of weights {(i,j,...),\n  (i+1,j,...), (i,j+1,...), (i+1,j+1,...), where i and j denote the index\n  dimensions of the main and conditional features and the ellipses represent\n  a fixed value of the other feature dimensions. It is immediately clear that\n  we can apply the constraint at the same time for different values of the\n  other dimensions. Considering then a fixed slice, and a grid\n\n  ```\n  0---1---2---3\n  |   |   |   |\n  4---5---6---7\n  |   |   |   |\n  8---9---10--11\n  |   |   |   |\n  12--13--14--15\n  ```\n\n  we get our four independent sets by considering non-overlapping squares of\n  constraints. In particular, we define the sets by the combination of even &\n  odd starting indices in each dimension. So if we start our indexing at the\n  top-left, the even/even set would be the four squares {0,1,4,5}, {2,3,6,7},\n  {8,9,12,13}, and {10,11,14,15}, the even/odd set would be {4,5,8,9} and\n  {6,7,10,11} and so on.\n\n  * Individual weight projection:\n  Within each square the projection moves each of the four weights by the\n  constraint violation / 4, if necessary, increasing the gap between high-trust\n  weights across the main feature and decreasing the gap between low-trust\n  weights across the main feature.\n\n  Differs from _approximately_project_edgeworth in that this algorithm\n  - Only operates on the constraints for a single (main_dim, cond_dim) pair.\n  - Does not guarantee a satisfying solution to the full trust constraint.\n  - Exactly projects (in L2 terms) on the subset of constraints it does\n    operate on.\n\n  Args:\n    weights: Tensor with weights of lattice layer, with shape lattice_sizes.\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n      which correspond to weights.\n    edgeworth_trust: Three-element tuple representing a single trust constraint.\n      First element is the index of the main (monotonic) feature. Second element\n      is the index of the conditional feature. 
Third element is the direction of\n      trust set to 1 if higher values of the conditional feature increase trust\n      and -1 otherwise.\n    constraint_group: Two-element tuple of 0s and 1s as defined above,\n      representing the combination of 'even' and 'odd' constraints we are\n      projecting on.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n\n  main_dim, cond_dim, cond_direction = edgeworth_trust\n  layers = _unstack_nd(weights, [main_dim, cond_dim])\n\n  if cond_direction < 0:\n    layers = _reverse_second_list_dimension(layers)\n  for i in range(constraint_group[0], lattice_sizes[main_dim] - 1, 2):\n    for j in range(constraint_group[1], lattice_sizes[cond_dim] - 1, 2):\n      difference_in_slopes = ((layers[i + 1][j] - layers[i][j]) -\n                              (layers[i + 1][j + 1] - layers[i][j + 1]))\n      correction = tf.maximum(difference_in_slopes / 4, 0)\n      layers[i][j] += correction\n      layers[i][j + 1] -= correction\n      layers[i + 1][j] -= correction\n      layers[i + 1][j + 1] += correction\n  if cond_direction < 0:\n    layers = _reverse_second_list_dimension(layers)\n\n  return _stack_nd(layers, [main_dim, cond_dim])\n\n\ndef _project_partial_trapezoid(weights, lattice_sizes, trapezoid_trust,\n                               constraint_group):\n  \"\"\"Applies exact trapezoid trust projection to a subset of one constraint.\n\n  Algorithm details:\n\n  For the trapezoid trust projection, each trust constraint touches every\n  lattice vertex, but can be broken up into 2 independent sets of constraints,\n  based on whether the constraint's smaller index along the conditional\n  dimension is even or odd. That leaves us with 2t sets of constraints if we\n  have t trust constraints, which we can sequentially project onto with the\n  Dykstra algorithm.\n\n  This function applies to a single set of independent constraints within a\n  single trust constraint. The constraint group can take the value 0 or 1,\n  corresponding to even (0) or odd (1) for conditional dimension index.\n\n  * k trust constraints projection:\n  If we know how to project into single trust constraint then we can use\n  Dykstra algorithm to project into union of all k trust constraints.\n\n  * Single trust constraint projection:\n  Trapezoid constraints require the range of possible model outputs across the\n  main feature to be larger when the conditional feature demonstrates higher\n  trust in the main feature. That is, they constrain the 'extreme' (minimum and\n  maximum) weights in the main feature dimension but not any of the weights in\n  the middle if the lattice size is larger than 2. We therefore have one set of\n  constraints along the conditional dimension when the main feature is at its\n  minimum and one when the main feature is at its maximum. For example, consider\n  the grid\n\n  ```\n  0---1---2---3\n  |   |   |   |\n  4---5---6---7\n  |   |   |   |\n  8---9---10--11\n  |   |   |   |\n  12--13--14--15\n  ```\n\n  If the main feature is on the x-axis and the conditional feature is on the y-\n  axis in this grid, our constraints operate on {0,4,8,12} and {3,7,11,15}. In\n  fact, those constraints are simply monotonicity constraints in opposite\n  directions. If the cond_direction = 1, we are monotonically decreasing between\n  12 and 0 (0 < 4 < 8 < 12) and monotonically increasing between 15 and 3\n  (3 > 7 > 11 > 15). Note that these imply that [0,3] is a superset of [4,7] and\n  so on down to the smallest subset [12,15]. 
Our two independent sets of these\n  constraints match those for monotonicity based on even and odd indices. For\n  example, [8 < 12], [4 < 0], [11 > 15], and [3 > 7] can be projected onto at\n  once, while [4 < 8] and [7 > 11] are in the other group. All constraint\n  directions are flipped if cond_direction = -1.\n\n  * Individual weight projection:\n  For each pair of constraints, we project as in monotonicity: each weight moves\n  halfway towards each other if the constraint is being violated, and stays the\n  same otherwise.\n\n  Differs from _approximately_project_trapezoid in that this algorithm\n  - Only operates on the constraints for a single (main_dim, cond_dim) pair.\n  - Does not guarantee a satisfying solution to the full trust constraint.\n  - Exactly projects (in L2 terms) on the subset of constraints it does\n    operate on.\n\n  Args:\n    weights: Tensor with weights of lattice layer, with shape lattice_sizes.\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n      which correspond to weights.\n    trapezoid_trust: Three-element tuple representing a single trust constraint.\n      First element is the index of the main (monotonic) feature. Second element\n      is the index of the conditional feature. Third element is the direction of\n      trust set to 1 if higher values of the conditional feature increase trust\n      and -1 otherwise.\n    constraint_group: 0 or 1 as defined above, representing whether we are\n      acting on even or odd indices\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n\n  main_dim, cond_dim, cond_direction = trapezoid_trust\n  layers = _unstack_nd(weights, [main_dim, cond_dim])\n\n  max_main_dim = lattice_sizes[main_dim] - 1\n  if cond_direction < 0:\n    layers = _reverse_second_list_dimension(layers)\n  for j in range(constraint_group, lattice_sizes[cond_dim] - 1, 2):\n    lhs_difference = layers[0][j + 1] - layers[0][j]\n    lhs_correction = tf.maximum(lhs_difference / 2, 0)\n    layers[0][j] += lhs_correction\n    layers[0][j + 1] -= lhs_correction\n\n    rhs_difference = layers[max_main_dim][j] - layers[max_main_dim][j + 1]\n    rhs_correction = tf.maximum(rhs_difference / 2, 0)\n    layers[max_main_dim][j] -= rhs_correction\n    layers[max_main_dim][j + 1] += rhs_correction\n  if cond_direction < 0:\n    layers = _reverse_second_list_dimension(layers)\n\n  return _stack_nd(layers, [main_dim, cond_dim])\n\n\ndef _project_partial_monotonic_dominance(weights, lattice_sizes,\n                                         monotonic_dominance, constraint_group):\n  r\"\"\"Applies exact monotonic dominance projection to given constraint group.\n\n  Algorithm details:\n\n  For the monotonic dominance projection, we follow a similar approach to the\n  monotonicity projection by splitting up the constraints into independent sets.\n  Here, each dominance constraint can be broken up into 8 independent sets of\n  constraints, based on (1) whether the constraint's smaller indices along the\n  dominant and weak dimensions are even or odd and (2) two triplets of vertices\n  to consider for each square in the grid shown below.\n\n  That leaves us with 8k sets of constraints if we have k dominance constraints,\n  which we can sequentially project onto with the Dykstra algorithm.\n\n  This function applies to a single set of independent constraints within a\n  single dominance constraint group. The constraint group can take the value\n  {0,1} x {0,1} x {0,1}. 
Even (0) or odd (1) of the first two elements\n  correspond to the dominant and weak features and the third element determines\n  which of the two triplets within a square to consider.\n\n  * k monotonic dominance constraints projection:\n  If we know how to project into single monotonic dominance constraint then we\n  can use Dykstra algorithm to project into union of all k dominance\n  constraints.\n\n  * Single monotonic dominance constraint projection\n  Monotonic dominance constraints require the effect (slope) in the direction\n  of the dominant dimension to be greater than that of the weak dimension for\n  any point in the lattice. We can think of this as separate constraints applied\n  to each 'triangle' of weights represented as either {(i,j,...), (i+1,j,...),\n  (i+1,j+1,...)} or {(i,j,...), (i,j+1,...), (i+1,j+1,...)} where i and j denote\n  the index dimensions of the dominant and weak features and the ellipses\n  represent a fixed value of the other feature dimensions. Considering then a\n  fixed slice, and a grid\n\n  ```\n  0---1---2---3\n  | \\ | \\ | \\ |\n  4---5---6---7\n  | \\ | \\ | \\ |\n  8---9---10--11\n  | \\ | \\ | \\ |\n  12--13--14--15\n  ```\n\n  where the dominant feature is on the x-axis and the weak feature is on the\n  y-axis, we get our 8 independent sets of non-overlapping triangular triplets\n  of vertices. For example, one set consists of {(0,1,4), (8,9,12), (2,3,6),\n  (10,11,14)}.\n\n  * Individual weight projection\n  Within each triangular triplet, the projection moves the weight of the right\n  angled vertex, either top-right or bottom-left, by 2 * violation / 3 and the\n  other two vertices by violation / 3 to satisfy the constraint while minimizing\n  the L2 distance from the initial point.\n\n  Args:\n    weights: tensor with weights of lattice layer, with shape lattice_sizes.\n    lattice_sizes: list or tuple of integers which represents lattice sizes\n      which correspond to weights.\n    monotonic_dominance: two-element tuple representing a single monotonic\n      dominance constraint. 
First element is the index of the dominant feature.\n      Second element is the index of the weak feature.\n    constraint_group: three-element tuple as defined above, representing 'even'\n      or 'odd' indices and which of the two triangles we are acting on.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n\n  dominant_dim, weak_dim = monotonic_dominance\n  layers = _unstack_nd(weights, [dominant_dim, weak_dim])\n  for i in range(constraint_group[0], lattice_sizes[dominant_dim] - 1, 2):\n    for j in range(constraint_group[1], lattice_sizes[weak_dim] - 1, 2):\n      midpoint = (layers[i][j] + layers[i + 1][j + 1]) / 2\n      if constraint_group[2] == 1:\n        difference = midpoint - layers[i + 1][j]\n        correction = tf.maximum(difference / 3, 0)\n        layers[i + 1][j] += 2 * correction\n      else:\n        difference = midpoint - layers[i][j + 1]\n        correction = tf.minimum(difference / 3, 0)\n        layers[i][j + 1] += 2 * correction\n      layers[i][j] -= correction\n      layers[i + 1][j + 1] -= correction\n\n  return _stack_nd(layers, [dominant_dim, weak_dim])\n\n\ndef _project_partial_range_dominance(weights, lattice_sizes, range_dominance,\n                                     constraint_group):\n  r\"\"\"Applies exact range dominance projection to given constraint group.\n\n  Algorithm details:\n\n  For the range dominance projection, each range dominance constraint can be\n  broken up into M x N independent constraints where M and N are the lattice\n  sizes of the dominant and weak dimensions. In other words, there are as many\n  constraints to project onto as there are vertices. This leaves us with\n  M x N x k constraints if we have k range dominance constraints, which we can\n  sequentially project onto with the Dykstra algorithm.\n\n  This function applies to a single independent constraint within a single range\n  dominance constraint as specified by the given constraint group.\n\n  * k range dominance constraints projection:\n  If we know how to project into single range dominance constraint then we can\n  use Dykstra algorithm to project into union of all k dominance constraints.\n\n  * Single range dominance constraint projection:\n  Range dominance constraints require the range of possible outputs to be\n  greater if one varies the dominant dimension than if one varies the weak\n  dimension for any point. Considering then a fixed slice, and a grid\n\n  ```\n  0---1---2---3\n  |   |   |   |\n  4---5---6---7\n  |   |   |   |\n  8---9---10--11\n  |   |   |   |\n  12--13--14--15\n  ```\n\n  where the dominant dimension is on the x-axis and the weak dimension is on the\n  y-axis, we get, for each vertex defined by the x and y coordinates, a\n  constraint where the range along the x-axis direction is required to be\n  greater than the range along the y-axis direction. For example, vertex 1\n  requires its dominant range defined by vertices 0 and 3 to be greater than\n  its weak range defined by vertices 1 and 13.\n\n  * Individual weight projection:\n  The projection moves the weights of all four vertices defining the dominant\n  and weak ranges by the constraint violation / 4 such that the dominant range\n  grows and the weak range shrinks. The only exception is the four corner\n  vertices, i.e. vertices 0, 3, 12, 15. In this case, there are three\n  participating vertices and since one of the vertices is shared by the two\n  conflicting ranges, we only move the weights of the other two vertices. 
This\n  means, for vertex 0, we move the weight of vertex 3 up halfway, the weight of\n  vertex 12 down halfway and leave the weight of vertex 0 unchanged.\n\n  Args:\n    weights: tensor with weights of lattice layer, with shape lattice_sizes.\n    lattice_sizes: list or tuple of integers which represents lattice sizes\n      which correspond to weights.\n    range_dominance: two-element tuple representing a single range dominance\n      constraint. First element is the index of the dominant feature. Second\n      element is the index of the weak feature.\n    constraint_group: two-element tuple as defined above, representing the\n      location of a vertex we are acting on.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n\n  dom_dim, weak_dim = range_dominance\n  dom_dim_size = lattice_sizes[dom_dim]\n  weak_dim_size = lattice_sizes[weak_dim]\n  i, j = constraint_group\n  layers = _unstack_nd(weights, [dom_dim, weak_dim])\n  difference = ((layers[i][weak_dim_size - 1] - layers[i][0]) -\n                (layers[dom_dim_size - 1][j] - layers[0][j]))\n  if (i == 0 or i == dom_dim_size - 1) and (j == 0 or j == weak_dim_size - 1):\n    correction = tf.maximum(difference / 2, 0)\n    if i == 0:\n      layers[dom_dim_size - 1][j] += correction\n    else:\n      layers[0][j] -= correction\n    if j == 0:\n      layers[i][weak_dim_size - 1] -= correction\n    else:\n      layers[i][0] += correction\n  else:\n    correction = tf.maximum(difference / 4, 0)\n    layers[i][weak_dim_size - 1] -= correction\n    layers[i][0] += correction\n    layers[dom_dim_size - 1][j] += correction\n    layers[0][j] -= correction\n\n  return _stack_nd(layers, [dom_dim, weak_dim])\n\n\ndef _project_partial_joint_monotonicity(weights, lattice_sizes,\n                                        joint_monotonicity, constraint_group):\n  \"\"\"Applies exact joint monotonicity projection to given constraint group.\n\n  Algorithm details:\n\n  For the joint monotonicity projection, we follow a similar approach to the\n  per-dimension monotonicity projection by splitting up the constraints into\n  independent sets. Here, each joint monotonicity constraint can be broken up\n  into 8 independent sets of constraints, based on (1) whether the constraint's\n  smaller indices along the two given dimensions are even or odd and (2) two\n  triplets of vertices to consider for each square in the grid shown below.\n\n  That leaves us with 8k sets of constraints if we have k joint monotonicity\n  constraints, which we can sequentially project onto with the Dykstra\n  algorithm.\n\n  This function applies to a single set of independent constraints within a\n  single joint monotonicity constraint. The constraint group can take the value\n  {0,1} x {0,1} x {0,1}. Even (0) or odd (1) of the first two elements\n  correspond to the two features that are jointly monotonic and the third\n  element determines which of the two triplets within a square to consider.\n\n  * k joint monotonicity constraints projection:\n  If we know how to project into single joint monotonicity constraint then we\n  can use Dykstra algorithm to project into union of all k joint monotonicity\n  constraints.\n\n  * Single joint monotonicity constraint projection\n  Joint monotonicity constraints require the function to be monotonic along a\n  diagonal direction of a two-feature subspace, ceteris paribus all other\n  features. The sum of the partial derivatives on the constraint features needs\n  to be non-negative. 
We can think of this as separate constraints applied to\n  each 'triangle' of weights represented as either {(i,j,...), (i+1,j,...),\n  (i,j+1,...)} or {(i+1,j+1,...), (i+1,j,...), (i,j+1,...)} where i and j\n  denote the index dimensions of the two features and the ellipses represent a\n  fixed value of the other feature dimensions. Considering then a fixed slice,\n  and a grid\n\n  ```\n  0---1---2---3\n  | / | / | / |\n  4---5---6---7\n  | / | / | / |\n  8---9---10--11\n  | / | / | / |\n  12--13--14--15\n  ```\n\n  we get our 8 independent sets of non-overlapping triangular triplets of\n  vertices. For example, one set consists of {(0,1,4), (8,9,12), (2,3,6),\n  (10,11,14)}.\n\n  * Individual weight projection\n  Within each triangular triplet, the projection moves the weight of the right\n  angled vertex, either top-left or bottom-right, by 2 * violation / 3 and the\n  other two vertices by violation / 3 to satisfy the constraint while minimizing\n  the L2 distance from the initial point.\n\n  Args:\n    weights: tensor with weights of lattice layer, with shape lattice_sizes.\n    lattice_sizes: list or tuple of integers which represents lattice sizes\n      which correspond to weights.\n    joint_monotonicity: two-element tuple representing a single joint\n      monotonicity constraint. The two elements are the indices of the two\n      constrained features.\n    constraint_group: three-element tuple as defined above, representing the\n      combination of 'even' and 'odd' constraints we are projecting on.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n\n  dim1, dim2 = joint_monotonicity\n  layers = _unstack_nd(weights, [dim1, dim2])\n  for i in range(constraint_group[0], lattice_sizes[dim1] - 1, 2):\n    for j in range(constraint_group[1], lattice_sizes[dim2] - 1, 2):\n      midpoint = (layers[i + 1][j] + layers[i][j + 1]) / 2\n      if constraint_group[2] == 1:\n        difference = midpoint - layers[i + 1][j + 1]\n        correction = tf.maximum(difference / 3, 0)\n        layers[i + 1][j + 1] += 2 * correction\n      else:\n        difference = midpoint - layers[i][j]\n        correction = tf.minimum(difference / 3, 0)\n        layers[i][j] += 2 * correction\n      layers[i + 1][j] -= correction\n      layers[i][j + 1] -= correction\n\n  return _stack_nd(layers, [dim1, dim2])\n\n\ndef _project_partial_joint_unimodality(weights, lattice_sizes,\n                                       joint_unimodalities, vertex, offsets):\n  \"\"\"Applies exact joint unimodality projection to given constraint group.\n\n  Constraint group is represented by vertex and offsets. Vertex is the lattice\n  vertex for which directional derivatives are being computed. Offsets is a\n  list of {-1, 1} which represents which of the hypercubes adjacent to the\n  vertex is being processed.\n  Each pair of vertex and offsets results in a linear equation involving\n  len(vertex) + 1 of the constrained vertices. 
We project onto this equation being\n  positive or negative depending on whether we need peak or valley.\n\n  Args:\n    weights: tensor with weights of lattice layer, with shape lattice_sizes.\n    lattice_sizes: list or tuple of integers which represents lattice sizes\n      which correspond to weights.\n    joint_unimodalities: tuple representing single joint unimodality constraint.\n      Elements are the indices of constrained features followed by 'valley' or\n      'peak'.\n    vertex: len(joint_unimodalities)-1 dimensional lattice vertex from\n      dimensions specified by joint_unimodalities.\n    offsets: list of {-1, 1} which represents which of the hypercubes adjacent\n      to vertex is being processed.\n\n  Returns:\n    None or tensor with projected weights matching shape of input weights.\n    None is returned when the pair (vertex, offsets) corresponds to a\n    constraint group for which no update to the weights is needed.\n  \"\"\"\n  # This function builds the hyperplane equation and then calls\n  # _project_onto_hyperplane() to project.\n  dimensions = joint_unimodalities[0]\n  if len(vertex) != len(dimensions):\n    raise ValueError(\"%s %s\" % (vertex, joint_unimodalities))\n\n  upper_bound = [lattice_sizes[dim] for dim in dimensions]\n  center = [size // 2 for size in upper_bound]\n\n  if all(v == c for v, c in zip(vertex, center)):\n    return None\n\n  equation = []\n  all_vertices = []\n\n  for dim, offset in enumerate(offsets):\n    dim_weight = vertex[dim] - center[dim]\n    if dim_weight == 0:\n      continue\n\n    neighbour = list(vertex)\n    neighbour[dim] += offset\n    if neighbour[dim] < 0 or neighbour[dim] >= upper_bound[dim]:\n      return None\n\n    all_vertices.append(neighbour)\n    equation.append(dim_weight * offset)\n\n  if not all_vertices:\n    return None\n\n  # Add 'vertex' itself with corresponding weights.\n  all_vertices.append(list(vertex))\n  equation.append(-sum(equation))\n\n  return _project_onto_hyperplane(\n      weights=weights,\n      joint_unimodalities=joint_unimodalities,\n      hyperplane=equation,\n      vertices=all_vertices)\n\n\ndef _project_onto_hyperplane(weights, joint_unimodalities, hyperplane,\n                             vertices):\n  \"\"\"Projects onto hyperplane.\n\n  Args:\n    weights: tensor with weights of lattice layer, with shape lattice_sizes.\n    joint_unimodalities: tuple representing a single joint unimodality\n      constraint. Elements are the indices of constrained features, followed\n      by 'valley' or 'peak'.\n    hyperplane: list of coefficients of hyperplane onto which we project.\n    vertices: list of len(joint_unimodalities)-1 dimensional points of length\n      len(hyperplane) which correspond to coefficients of hyperplane. These\n      points will be used to extract elements from 'weights' which are related\n      to given hyperplane.\n\n  Returns:\n    Tensor with projected weights matching shape of input weights.\n  \"\"\"\n  hyperplane = tf.constant(hyperplane, dtype=weights.dtype)\n\n  # TODO: unstacking entire set of weights for the purpose of projection\n  # onto single hyperplane is very inefficient for high number of jointly\n  # unimodal dims. Consider other options. So far I see 4 candidates:\n  # 1) Find a way to efficiently combine independent hyperplanes so we can\n  #    project onto several hyperplanes at once. 
This would be correct\n  #    projection with respect to L2 norm, but headroom for this approach is\n  #    limited because for example for 4 constrained dims of size 3 (3^4) we\n  #    have 81 different variables and 5 variables per equation. This gives us\n  #    an upper bound of 81/5 = 16 times speed up. In reality it will probably\n  #    be around 5-10 times.\n  # 2) Use tf.gather_nd() to gather affected weights instead of stacking and\n  #    unstacking. Hard to estimate how much of an improvement it will be.\n  # 3) Project onto all hyperplanes in a single step. This violates Dykstra's\n  #    algorithm, so the projection will not be onto the nearest point because\n  #    according to Dykstra we need to project onto all dependent hyperplanes\n  #    sequentially. But regardless it could work well enough in practice and\n  #    hopefully will be fast enough.\n  # 4) Come up with a better option.\n  dimensions, direction = joint_unimodalities\n  layers = _unstack_nd(weights, dims=dimensions)\n  affected_weights = [\n      _get_element(lists=layers, indices=position) for position in vertices\n  ]\n\n  affected_weights = tf.stack(affected_weights, axis=-1)\n  violation = tf.reduce_sum(affected_weights * hyperplane, axis=-1)\n  if direction == \"valley\":\n    violation = tf.minimum(violation, 0.0)\n  else:\n    violation = tf.maximum(violation, 0.0)\n\n  correction_factor = violation / tf.reduce_sum(hyperplane * hyperplane)\n  correction = tf.expand_dims(correction_factor, axis=-1) * hyperplane\n  projection = affected_weights - correction\n\n  affected_weights = tf.unstack(projection, axis=-1)\n  for tensor, position in zip(affected_weights, vertices):\n    _set_element(lists=layers, indices=position, value=tensor)\n\n  return _stack_nd(layers, dims=dimensions)\n\n\n# TODO: Test whether adding min/max capping to dykstra projection would\n# improve performance.\ndef project_by_dykstra(weights,\n                       lattice_sizes,\n                       monotonicities=None,\n                       unimodalities=None,\n                       edgeworth_trusts=None,\n                       trapezoid_trusts=None,\n                       monotonic_dominances=None,\n                       range_dominances=None,\n                       joint_monotonicities=None,\n                       joint_unimodalities=None,\n                       num_iterations=1):\n  \"\"\"Applies Dykstra's projection algorithm for monotonicity/trust constraints.\n\n  - Returns honest projection with respect to L2 norm if num_iterations is inf.\n  - Monotonicity will be violated by some small eps(num_iterations).\n  - Complexity: O(num_iterations * (num_monotonic_dims + num_trust_constraints)\n    * num_lattice_weights)\n\n  Dykstra's alternating projections algorithm projects into intersection of\n  several convex sets. For algorithm description itself use Google or Wiki:\n  https://en.wikipedia.org/wiki/Dykstra%27s_projection_algorithm\n\n  Here, each monotonicity constraint is split up into 2 independent convex sets\n  and each trust constraint is split up into 4 independent convex sets. These\n  sets are then projected onto exactly (in L2 space). 
For more details, see the\n  _project_partial_* functions.\n\n  Args:\n    weights: `Lattice` weights tensor of shape: `(prod(lattice_sizes), units)`.\n    lattice_sizes: list or tuple of integers which represents lattice sizes.\n      which correspond to weights.\n    monotonicities: None or list or tuple of same length as lattice_sizes of {0,\n      1} which represents monotonicity constraints per dimension. 1 stands for\n      increasing (non-decreasing in fact), 0 for no monotonicity constraints.\n    unimodalities: None or list or tuple of same length as lattice_sizes of {-1,\n      0, 1} which represents unimodality constraints per dimension. 1 indicates\n      that function first decreases then increases, -1 indicates that function\n      first increases then decreases, 0 indicates no unimodality constraints.\n    edgeworth_trusts: None or iterable of three-element tuples. First element is\n      the index of the main (monotonic) feature. Second element is the index of\n      the conditional feature. Third element is the direction of trust: 1 if\n        higher values of the conditional feature should increase trust in the\n        main feature and -1 otherwise.\n    trapezoid_trusts: None or iterable of three-element tuples. First element is\n      the index of the main (monotonic) feature. Second element is the index of\n      the conditional feature. Third element is the direction of trust: 1 if\n        higher values of the conditional feature should increase trust in the\n        main feature and -1 otherwise.\n    monotonic_dominances: None or iterable of two-element tuples. First element\n      is the index of the dominant feature. Second element is the index of the\n      weak feature.\n    range_dominances: None or iterable of two-element tuples. First element is\n      the index of the dominant feature. Second element is the index of the weak\n      feature.\n    joint_monotonicities: None or iterable of two-element tuples. Each tuple\n      represents a pair of feature indices that require joint monotonicity.\n    joint_unimodalities: None or tuple or iterable of tuples. 
Each tuple\n      represents indices of a single group of jointly unimodal features\n      followed by 'valley' or 'peak'.\n    num_iterations: number of iterations of Dykstra's algorithm.\n\n  Returns:\n    Projected weights tensor of same shape as `weights`.\n  \"\"\"\n  if num_iterations == 0:\n    return weights\n  if (utils.count_non_zeros(monotonicities, unimodalities) == 0 and\n      not joint_monotonicities and not joint_unimodalities and\n      not range_dominances):\n    return weights\n\n  units = weights.shape[1]\n  if monotonicities is None:\n    monotonicities = [0] * len(lattice_sizes)\n  if unimodalities is None:\n    unimodalities = [0] * len(lattice_sizes)\n  if edgeworth_trusts is None:\n    edgeworth_trusts = []\n  if trapezoid_trusts is None:\n    trapezoid_trusts = []\n  if monotonic_dominances is None:\n    monotonic_dominances = []\n  if range_dominances is None:\n    range_dominances = []\n  if joint_monotonicities is None:\n    joint_monotonicities = []\n  if joint_unimodalities is None:\n    joint_unimodalities = []\n  if units > 1:\n    lattice_sizes = lattice_sizes + [int(units)]\n    monotonicities = monotonicities + [0]\n    unimodalities = unimodalities + [0]\n\n  weights = tf.reshape(weights, lattice_sizes)\n\n  def body(iteration, weights, last_change):\n    \"\"\"Body of the tf.while_loop for Dykstra's projection algorithm.\n\n    This implements Dykstra's projection algorithm and requires rolling back\n    the last projection change.\n\n    Args:\n      iteration: Iteration counter tensor.\n      weights: Tensor with projected weights at each iteration.\n      last_change: Dict that stores the last change in the weights after\n        projecting onto each subset of constraints.\n\n    Returns:\n      The tuple (iteration, weights, last_change) at the end of each iteration.\n    \"\"\"\n    last_change = copy.copy(last_change)\n    for dim in range(len(lattice_sizes)):\n      if monotonicities[dim] == 0 and unimodalities[dim] == 0:\n        continue\n\n      for constraint_group in [0, 1]:\n        # Iterate over 2 sets of constraints per dimension: even and odd.\n        # Odd set exists only when there are more than 2 lattice vertices.\n        if constraint_group + 1 >= lattice_sizes[dim]:\n          continue\n\n        # Rolling back last projection into current set as required by Dykstra's\n        # algorithm.\n        rolled_back_weights = weights - last_change[(\"MONOTONICITY\", dim,\n                                                     constraint_group)]\n        weights = _project_partial_monotonicity(rolled_back_weights,\n                                                lattice_sizes, monotonicities,\n                                                unimodalities, dim,\n                                                constraint_group)\n        last_change[(\"MONOTONICITY\", dim,\n                     constraint_group)] = weights - rolled_back_weights\n\n    for constraint in edgeworth_trusts:\n      main_dim, cond_dim, _ = constraint\n      for constraint_group in [(0, 0), (0, 1), (1, 0), (1, 1)]:\n        if (constraint_group[0] >= lattice_sizes[main_dim] - 1 or\n            constraint_group[1] >= lattice_sizes[cond_dim] - 1):\n          continue\n\n        rolled_back_weights = (\n            weights - last_change[(\"EDGEWORTH\", constraint, constraint_group)])\n        weights = _project_partial_edgeworth(rolled_back_weights, lattice_sizes,\n                                             constraint, constraint_group)\n        
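# Store this projection's change so it can be rolled back on the next\n        # pass over this constraint set, as required by Dykstra's algorithm.\n        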
last_change[(\"EDGEWORTH\", constraint,\n                     constraint_group)] = weights - rolled_back_weights\n\n    for constraint in trapezoid_trusts:\n      _, cond_dim, _ = constraint\n      for constraint_group in [0, 1]:\n        if constraint_group >= lattice_sizes[cond_dim] - 1:\n          continue\n\n        rolled_back_weights = (\n            weights - last_change[(\"TRAPEZOID\", constraint, constraint_group)])\n        weights = _project_partial_trapezoid(rolled_back_weights, lattice_sizes,\n                                             constraint, constraint_group)\n        last_change[(\"TRAPEZOID\", constraint,\n                     constraint_group)] = weights - rolled_back_weights\n\n    for constraint in monotonic_dominances:\n      dominant_dim, weak_dim = constraint\n      for constraint_group in itertools.product([0, 1], [0, 1], [0, 1]):\n        if (constraint_group[0] >= lattice_sizes[dominant_dim] - 1 or\n            constraint_group[1] >= lattice_sizes[weak_dim] - 1):\n          continue\n\n        rolled_back_weights = weights - last_change[\n            (\"MONOTONIC_DOMINANCE\", constraint, constraint_group)]\n        weights = _project_partial_monotonic_dominance(rolled_back_weights,\n                                                       lattice_sizes,\n                                                       constraint,\n                                                       constraint_group)\n        last_change[(\"MONOTONIC_DOMINANCE\", constraint,\n                     constraint_group)] = weights - rolled_back_weights\n\n    for constraint in range_dominances:\n      dominant_dim, weak_dim = constraint\n      dom_dim_idx = range(lattice_sizes[dominant_dim])\n      weak_dim_idx = range(lattice_sizes[weak_dim])\n      for constraint_group in itertools.product(dom_dim_idx, weak_dim_idx):\n        rolled_back_weights = weights - last_change[\n            (\"RANGE_DOMINANCE\", constraint, constraint_group)]\n        weights = _project_partial_range_dominance(rolled_back_weights,\n                                                   lattice_sizes, constraint,\n                                                   constraint_group)\n        last_change[(\"RANGE_DOMINANCE\", constraint,\n                     constraint_group)] = weights - rolled_back_weights\n\n    for constraint in joint_monotonicities:\n      dim1, dim2 = constraint\n      for constraint_group in itertools.product([0, 1], [0, 1], [0, 1]):\n        if (constraint_group[0] >= lattice_sizes[dim1] - 1 or\n            constraint_group[1] >= lattice_sizes[dim2] - 1):\n          continue\n\n        rolled_back_weights = weights - last_change[\n            (\"JOINT_MONOTONICITY\", constraint, constraint_group)]\n        weights = _project_partial_joint_monotonicity(rolled_back_weights,\n                                                      lattice_sizes, constraint,\n                                                      constraint_group)\n        last_change[(\"JOINT_MONOTONICITY\", constraint,\n                     constraint_group)] = weights - rolled_back_weights\n\n    for constraint in joint_unimodalities:\n      dimensions = tuple(constraint[0])\n      lattice_ranges = [range(lattice_sizes[dim]) for dim in dimensions]\n      for vertex in itertools.product(*lattice_ranges):\n        for offsets in itertools.product([-1, 1], repeat=len(dimensions)):\n          # For this projection constraint group is represented by pair: vertex,\n          # offsets.\n          projection_key = (\"JOINT_UNIMODALITY\", 
dimensions, vertex, offsets)\n          if projection_key in last_change:\n            rolled_back_weights = weights - last_change[projection_key]\n          else:\n            rolled_back_weights = weights\n          projected_weights = _project_partial_joint_unimodality(\n              weights=rolled_back_weights,\n              lattice_sizes=lattice_sizes,\n              joint_unimodalities=constraint,\n              vertex=vertex,\n              offsets=offsets)\n          if projected_weights is not None:\n            weights = projected_weights\n            last_change[projection_key] = weights - rolled_back_weights\n    return iteration + 1, weights, last_change\n\n  def cond(iteration, weights, last_change):\n    del weights, last_change\n    return tf.less(iteration, num_iterations)\n\n  # Run the body of the loop once to find required last_change keys. The set of\n  # keys in the input and output of the body of tf.while_loop must be the same.\n  # The resulting ops are discarded and will not be part of the TF graph.\n  zeros = tf.zeros(shape=lattice_sizes, dtype=weights.dtype)\n  last_change = collections.defaultdict(lambda: zeros)\n  (_, _, last_change) = body(0, weights, last_change)\n\n  # Apply Dykstra's algorithm with tf.while_loop.\n  iteration = tf.constant(0)\n  last_change = {k: zeros for k in last_change}\n  (_, weights, _) = tf.while_loop(cond, body, (iteration, weights, last_change))\n  return tf.reshape(weights, shape=[-1, units])\n\n\ndef laplacian_regularizer(weights, lattice_sizes, l1=0.0, l2=0.0):\n  \"\"\"Returns Laplacian regularization loss for `Lattice` layer.\n\n  Laplacian regularizer penalizes the difference between adjacent vertices in\n  multi-cell lattice (see\n  [publication](http://jmlr.org/papers/v17/15-243.html)).\n\n  Consider a 3 x 2 lattice with weights `w`:\n\n  ```\n  w[3]-----w[4]-----w[5]\n    |        |        |\n    |        |        |\n  w[0]-----w[1]-----w[2]\n  ```\n\n  where the number at each node represents the weight index.\n  In this case, the laplacian regularizer is defined as:\n\n  ```\n  l1[0] * (|w[1] - w[0]| + |w[2] - w[1]| +\n           |w[4] - w[3]| + |w[5] - w[4]|) +\n  l1[1] * (|w[3] - w[0]| + |w[4] - w[1]| + |w[5] - w[2]|) +\n\n  l2[0] * ((w[1] - w[0])^2 + (w[2] - w[1])^2 +\n           (w[4] - w[3])^2 + (w[5] - w[4])^2) +\n  l2[1] * ((w[3] - w[0])^2 + (w[4] - w[1])^2 + (w[5] - w[2])^2)\n  ```\n\n  Arguments:\n    weights: `Lattice` weights tensor of shape: `(prod(lattice_sizes), units)`.\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n    l1: l1 regularization amount. Either single float or list or tuple of floats\n      to specify different regularization amount per dimension.\n    l2: l2 regularization amount. 
Either single float or list or tuple of floats\n      to specify different regularization amount per dimension.\n\n  Returns:\n    Laplacian regularization loss.\n  \"\"\"\n  if not l1 and not l2:\n    return 0.0\n\n  rank = len(lattice_sizes)\n  # If regularization amount is given as single float assume same amount for\n  # every dimension.\n  if l1 and not isinstance(l1, (list, tuple)):\n    l1 = [l1] * rank\n  if l2 and not isinstance(l2, (list, tuple)):\n    l2 = [l2] * rank\n\n  if weights.shape[1] > 1:\n    lattice_sizes = lattice_sizes + [int(weights.shape[1])]\n    rank += 1\n    if l1:\n      l1 = l1 + [0.0]\n    if l2:\n      l2 = l2 + [0.0]\n  weights = tf.reshape(weights, shape=lattice_sizes)\n\n  result = tf.constant(0.0, shape=[], dtype=weights.dtype)\n  for dim in range(rank):\n    if (not l1 or not l1[dim]) and (not l2 or not l2[dim]):\n      continue\n    if dim > 0:\n      # Transpose so current dimension becomes first one in order to simplify\n      # indexing and be able to merge all other dimensions into 1 for better TPU\n      # performance.\n      permut = [p for p in range(rank)]\n      permut[0], permut[dim] = permut[dim], permut[0]\n      slices = tf.transpose(weights, perm=permut)\n    else:\n      slices = weights\n    slices = tf.reshape(slices, shape=[lattice_sizes[dim], -1])\n\n    diff = slices[1:] - slices[0:-1]\n    if l1:\n      result += tf.reduce_sum(tf.abs(diff)) * l1[dim]\n    if l2:\n      result += tf.reduce_sum(tf.square(diff)) * l2[dim]\n  return result\n\n\ndef torsion_regularizer(weights, lattice_sizes, l1=0.0, l2=0.0):\n  \"\"\"Returns Torsion regularization loss for `Lattice` layer.\n\n  Lattice torsion regularizer penalizes how much the lattice function twists\n  from side-to-side (see\n  [publication](http://jmlr.org/papers/v17/15-243.html)).\n\n  Consider a 3 x 2 lattice with weights `w`:\n\n  ```\n  w[3]-----w[4]-----w[5]\n    |        |        |\n    |        |        |\n  w[0]-----w[1]-----w[2]\n  ```\n\n  In this case, the torsion regularizer is defined as:\n\n  ```\n  l1 * (|w[4] + w[0] - w[3] - w[1]| + |w[5] + w[1] - w[4] - w[2]|) +\n  l2 * ((w[4] + w[0] - w[3] - w[1])^2 + (w[5] + w[1] - w[4] - w[2])^2)\n  ```\n\n  Arguments:\n    weights: `Lattice` weights tensor of shape: `(prod(lattice_sizes), units)`.\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n    l1: l1 regularization amount. Either single float or list or tuple of floats\n      to specify different regularization amount per dimension.\n    l2: l2 regularization amount. Either single float or list or tuple of floats\n      to specify different regularization amount per dimension. 
The amount for\n      the interaction term between i and j is the corresponding product of each\n      per feature amount.\n\n  Returns:\n    Torsion regularization loss.\n  \"\"\"\n  rank = len(lattice_sizes)\n  if rank == 1 or (not l1 and not l2):\n    return 0.0\n\n  # If regularization amount is given as single float assume same amount for\n  # every dimension.\n  if l1 and not isinstance(l1, (list, tuple)):\n    l1 = [math.sqrt(l1)] * rank\n  if l2 and not isinstance(l2, (list, tuple)):\n    l2 = [math.sqrt(l2)] * rank\n\n  if weights.shape[1] > 1:\n    lattice_sizes = lattice_sizes + [int(weights.shape[1])]\n    rank += 1\n    if l1:\n      l1 = l1 + [0.0]\n    if l2:\n      l2 = l2 + [0.0]\n  weights = tf.reshape(weights, shape=lattice_sizes)\n\n  result = tf.constant(0.0, shape=[], dtype=weights.dtype)\n  for i in range(rank - 1):\n    for j in range(i + 1, rank):\n      if ((not l1 or not l1[i] or not l1[j]) and\n          (not l2 or not l2[i] or not l2[j])):\n        continue\n      if j == 1:\n        planes = weights\n      else:\n        # Transpose so dimensions i and j become first in order to simplify\n        # indexing and be able to merge all other dimensions into 1 for better\n        # TPU performance.\n        permut = [p for p in range(rank)]\n        permut[0], permut[i] = permut[i], permut[0]\n        permut[1], permut[j] = permut[j], permut[1]\n        planes = tf.transpose(weights, perm=permut)\n      planes = tf.reshape(\n          planes, shape=[lattice_sizes[i], lattice_sizes[j], -1])\n\n      a00 = planes[0:-1, 0:-1]\n      a01 = planes[0:-1, 1:]\n      a10 = planes[1:, 0:-1]\n      a11 = planes[1:, 1:]\n      torsion = a00 + a11 - a01 - a10\n\n      if l1:\n        result += tf.reduce_sum(tf.abs(torsion)) * l1[i] * l1[j]\n      if l2:\n        result += tf.reduce_sum(tf.square(torsion)) * l2[i] * l2[j]\n  return result\n\n\ndef _verify_dominances_hyperparameters(dominances, dominance_type,\n                                       monotonicities, num_input_dims):\n  \"\"\"Verifies that dominances hyperparameters are consistent.\n\n  Args:\n    dominances: Dominances hyperparameters of `Lattice` layer.\n    dominance_type: Type of dominance constraints which is either 'monotonic' or\n      'range'.\n    monotonicities: Monotonicities hyperparameter of `Lattice` layer.\n    num_input_dims: Number of input dimensions.\n\n  Raises:\n    ValueError: If something is inconsistent.\n  \"\"\"\n  assert dominance_type in (\"monotonic\", \"range\")\n  dim_pairs = set()\n  for constraint in dominances:\n    if len(constraint) != 2:\n      raise ValueError(\"%s dominance constraints must consist of 2 elements. \"\n                       \"Seeing constraint tuple %s\" %\n                       (dominance_type.capitalize(), constraint))\n    dominant_dim, weak_dim = constraint\n    if (dominant_dim >= num_input_dims or weak_dim >= num_input_dims or\n        dominant_dim < 0 or weak_dim < 0):\n      raise ValueError(\"Dimensions constrained by %s dominance constraints are \"\n                       \"not within the range of the lattice. 'dims': %s, %s, \"\n                       \"num_dims: %s\" %\n                       (dominance_type, dominant_dim, weak_dim, num_input_dims))\n    if not isinstance(dominant_dim, int) or not isinstance(weak_dim, int):\n      raise ValueError(\"%s dominance constraint dimensions must be integers. 
\"\n                       \"Seeing dominant_dim %s and weak_dim %s\" %\n                       (dominance_type.capitalize(), dominant_dim, weak_dim))\n    for dim in [dominant_dim, weak_dim]:\n      if monotonicities[dim] != 1:\n        raise ValueError(\"%s dominance constraint's dimensions must be \"\n                         \"monotonic. Dimension %d is not monotonic.\" %\n                         (dominance_type.capitalize(), dim))\n    # TODO: Determine partial ordering of features by dominance and\n    # detect any inconsistencies.\n    if (weak_dim, dominant_dim) in dim_pairs:\n      raise ValueError(\"Cannot have two conflicting %s dominance constraints \"\n                       \"on the same pair of features. Features: %d, %d\" %\n                       (dominance_type, dominant_dim, weak_dim))\n    dim_pairs.add((dominant_dim, weak_dim))\n\n\ndef verify_hyperparameters(lattice_sizes,\n                           units=None,\n                           weights_shape=None,\n                           input_shape=None,\n                           monotonicities=None,\n                           unimodalities=None,\n                           edgeworth_trusts=None,\n                           trapezoid_trusts=None,\n                           monotonic_dominances=None,\n                           range_dominances=None,\n                           joint_monotonicities=None,\n                           joint_unimodalities=None,\n                           output_min=None,\n                           output_max=None,\n                           regularization_amount=None,\n                           regularization_info=\"\",\n                           interpolation=\"hypercube\"):\n  \"\"\"Verifies that all given hyperparameters are consistent.\n\n  This function does not inspect weights themselves, only their shape. Use\n  `assert_constraints()` to assert actual weights against constraints.\n\n  See `tfl.layers.Lattice` class level comment for detailed description of\n  arguments.\n\n  Args:\n    lattice_sizes: Lattice sizes to check against.\n    units: Units hyperparameter of `Lattice` layer.\n    weights_shape: Shape of tensor which represents `Lattice` layer weights.\n    input_shape: Shape of layer input. Useful only if `units` is set.\n    monotonicities: Monotonicities hyperparameter of `Lattice` layer.\n    unimodalities: Unimodalities hyperparameter of `Lattice` layer.\n    edgeworth_trusts: Edgeworth_trusts hyperparameter of `Lattice` layer.\n    trapezoid_trusts: Trapezoid_trusts hyperparameter of `Lattice` layer.\n    monotonic_dominances: Monotonic dominances hyperparameter of `Lattice`\n      layer.\n    range_dominances: Range dominances hyperparameter of `Lattice` layer.\n    joint_monotonicities: Joint monotonicities hyperparameter of `Lattice`\n      layer.\n    joint_unimodalities: Joint unimodalities hyperparameter of `Lattice` layer.\n    output_min: Minimum output of `Lattice` layer.\n    output_max: Maximum output of `Lattice` layer.\n    regularization_amount: Regularization amount for regularizers.\n    regularization_info: String which describes `regularization_amount`.\n    interpolation: One of 'simplex' or 'hypercube' interpolation.\n\n  Raises:\n    ValueError: If something is inconsistent.\n  \"\"\"\n  for size in lattice_sizes:\n    if size < 2:\n      raise ValueError(\"All lattice sizes must be at least 2. 
Given: %s\" %\n                       lattice_sizes)\n\n  # It also raises errors if monotonicities are specified incorrectly.\n  monotonicities = utils.canonicalize_monotonicities(\n      monotonicities, allow_decreasing=False)\n  if monotonicities is not None:\n    if len(monotonicities) != len(lattice_sizes):\n      raise ValueError(\"If provided 'monotonicities' should have same number \"\n                       \"of elements as 'lattice_sizes'. 'monotonicities': %s, \"\n                       \"'lattice_sizes': %s\" % (monotonicities, lattice_sizes))\n\n  unimodalities = utils.canonicalize_unimodalities(unimodalities)\n  if unimodalities is not None:\n    if len(unimodalities) != len(lattice_sizes):\n      raise ValueError(\"If provided 'unimodalities' should have same number \"\n                       \"of elements as 'lattice_sizes'. 'unimodalities': %s, \"\n                       \"'lattice_sizes': %s\" % (unimodalities, lattice_sizes))\n    for unimodality, dim_size in zip(unimodalities, lattice_sizes):\n      if unimodality != 0 and dim_size < 3:\n        raise ValueError(\"Unimodal dimensions must have lattice size at \"\n                         \"least 3. unimodalities: %s, lattice_sizes: %s\" %\n                         (unimodalities, lattice_sizes))\n\n  if monotonicities is not None and unimodalities is not None:\n    for i, (monotonicity,\n            unimodality) in enumerate(zip(monotonicities, unimodalities)):\n      if monotonicity != 0 and unimodality != 0:\n        raise ValueError(\"Both monotonicity and unimodality cannot be set \"\n                         \"simultaneously for the same dimension. \"\n                         \"Dimension: %d, 'monotonicities': %s, \"\n                         \"'unimodalities': %s\" %\n                         (i, monotonicities, unimodalities))\n\n  all_trusts = utils.canonicalize_trust((edgeworth_trusts or []) +\n                                        (trapezoid_trusts or [])) or []\n  main_dims, cond_dims, trapezoid_cond_dims = set(), set(), set()\n  dim_pairs_direction = {}\n  for i, constraint in enumerate(all_trusts):\n    main_dim, cond_dim, cond_direction = constraint\n    if (main_dim >= len(lattice_sizes) or cond_dim >= len(lattice_sizes) or\n        main_dim < 0 or cond_dim < 0):\n      raise ValueError(\"Dimensions constrained by trust constraints \"\n                       \"are not within the range of the lattice. \"\n                       \"'trust_dims': %s, %s, num_dims: %s\" %\n                       (main_dim, cond_dim, len(lattice_sizes)))\n    if not isinstance(main_dim, int) or not isinstance(cond_dim, int):\n      raise ValueError(\"Trust constraint dimensions must be integers. Seeing \"\n                       \"main_dim %s and cond_dim %s\" % (main_dim, cond_dim))\n    if monotonicities[main_dim] != 1:\n      raise ValueError(\"Trust constraint's main feature must be \"\n                       \"monotonic. Dimension %s is not monotonic.\" % (main_dim))\n    if (main_dim, cond_dim) in dim_pairs_direction and dim_pairs_direction[\n        (main_dim, cond_dim)] != cond_direction:\n      raise ValueError(\"Cannot have two trust constraints on the same pair of \"\n                       \"features in opposite directions. 
Features: %d, %d\" %\n                       (main_dim, cond_dim))\n    # Only apply this check to trapezoid constraints when there are also\n    # edgeworth constraints.\n    if edgeworth_trusts and i >= len(edgeworth_trusts):\n      if cond_dim in trapezoid_cond_dims:\n        logging.warning(\n            \"Conditional dimension %d is being used in multiple trapezoid \"\n            \"trust constraints. Because of this and the presence of edgeworth \"\n            \"constraints, there may be slight trust violations of one or more \"\n            \"of these constraints at the end of training. Consider increasing \"\n            \"num_projection_iterations to reduce violation.\", cond_dim)\n      trapezoid_cond_dims.add(cond_dim)\n    main_dims.add(main_dim)\n    cond_dims.add(cond_dim)\n    dim_pairs_direction[(main_dim, cond_dim)] = cond_direction\n  main_and_cond = main_dims.intersection(cond_dims)\n  if main_and_cond:\n    raise ValueError(\"A feature cannot be both a main feature and a \"\n                     \"conditional feature in trust constraints. \"\n                     \"Seeing dimension %d in both\" % (main_and_cond.pop()))\n\n  if monotonic_dominances is not None:\n    _verify_dominances_hyperparameters(monotonic_dominances, \"monotonic\",\n                                       monotonicities, len(lattice_sizes))\n  if range_dominances is not None:\n    _verify_dominances_hyperparameters(range_dominances, \"range\",\n                                       monotonicities, len(lattice_sizes))\n\n  if joint_monotonicities is not None:\n    for i, constraint in enumerate(joint_monotonicities):\n      if len(constraint) != 2:\n        raise ValueError(\"Joint monotonicities constraints must consist of 2 \"\n                         \"elements. Seeing constraint tuple %s\" % (constraint,))\n      dim1, dim2 = constraint\n      if (dim1 >= len(lattice_sizes) or dim2 >= len(lattice_sizes) or\n          dim1 < 0 or dim2 < 0):\n        raise ValueError(\"Dimensions constrained by joint monotonicity \"\n                         \"constraints are not within the range of the lattice. \"\n                         \"'dims': %s, %s, num_dims: %s\" %\n                         (dim1, dim2, len(lattice_sizes)))\n      if not isinstance(dim1, int) or not isinstance(dim2, int):\n        raise ValueError(\"Joint monotonicity constraint dimensions must be \"\n                         \"integers. Seeing dimensions %s, %s\" % (dim1, dim2))\n\n  if joint_unimodalities is not None:\n    for single_constraint in joint_unimodalities:\n      dimensions, direction = single_constraint\n      if (not isinstance(direction, six.string_types) or\n          (direction.lower() != \"valley\" and direction.lower() != \"peak\")):\n        raise ValueError(\"Joint unimodality tuple must end with string 'valley'\"\n                         \" or 'peak' which represents unimodality direction. \"\n                         \"Given: %s\" % (single_constraint,))\n      for dim in dimensions:\n        if dim < 0 or dim >= len(lattice_sizes):\n          raise ValueError(\"Dimension constrained by joint unimodality is not \"\n                           \"within the range of the lattice. Joint unimodality \"\n                           \"dimension: %s, total number of dimensions: \"\n                           \"%s\" % (dim, len(lattice_sizes)))\n        if not isinstance(dim, int):\n          raise ValueError(\"Joint unimodality constraint dimensions must be \"\n                           \"integer. 
Seeing: %s\" % dim)\n        if lattice_sizes[dim] < 3:\n          raise ValueError(\"Dimensions constrained for joint unimodality must \"\n                           \"have lattice size at least 3. \"\n                           \"Dim: %s has size: %s\" % (dim, lattice_sizes[dim]))\n        if monotonicities and monotonicities[dim] != 0:\n          raise ValueError(\"Dimension %d constrained for joint_unimodalities \"\n                           \"cannot also be monotonic.\" % dim)\n      dims_set = set(dimensions)\n      if len(dims_set) != len(dimensions):\n        raise ValueError(\"All dimensions within single joint unimodality \"\n                         \"constraint must be distinct. \"\n                         \"Given: %s\" % single_constraint)\n\n  if weights_shape is not None:\n    if len(weights_shape) != 2:\n      raise ValueError(\"Weights must have shape of rank-2. \"\n                       \"Given: %s\" % weights_shape)\n    expected_num_weights = 1\n    for dim_size in lattice_sizes:\n      expected_num_weights *= dim_size\n    if weights_shape[0] != expected_num_weights:\n      raise ValueError(\"Number of elements in weights does not correspond to \"\n                       \"lattice sizes. Weights shape: %s, lattice sizes: %s, \"\n                       \"Number of elements defined by lattice sizes: %d\" %\n                       (weights_shape, lattice_sizes, expected_num_weights))\n\n  if input_shape is not None:\n    if isinstance(input_shape[-1], int):\n      if input_shape[-1] != len(lattice_sizes):\n        raise ValueError(\"Last dimension of input shape must have same number \"\n                         \"of elements as 'lattice_sizes'. 'input shape': %s, \"\n                         \"'lattice_sizes': %s\" % (input_shape, lattice_sizes))\n      shape = input_shape\n    else:\n      if len(input_shape) != len(lattice_sizes):\n        raise ValueError(\"If lattice input is provided as list of tensors their\"\n                         \" number must match lattice_sizes. 'input list': %s, \"\n                         \"'lattice_sizes': %s\" % (input_shape, lattice_sizes))\n      shape = input_shape[0]\n    if units is not None:  # FYI: It is inside \"if input_shape is not None:\"\n      if units > 1 and (len(shape) < 3 or shape[-2] != units):\n        raise ValueError(\"If 'units' > 1 then input shape of Lattice layer must\"\n                         \" have rank at least 3 where second from last \"\n                         \"dimension is equal to 'units'. 'units': %s, \"\n                         \"input_shape: %s\" % (units, input_shape))\n\n  if output_min is not None and output_max is not None:\n    if output_min >= output_max:\n      raise ValueError(\"'output_min' must be less than 'output_max'. \"\n                       \"'output_min': %f, 'output_max': %f\" %\n                       (output_min, output_max))\n\n  if regularization_amount and isinstance(regularization_amount, (list, tuple)):\n    if len(regularization_amount) != len(lattice_sizes):\n      raise ValueError(\n          \"If %s losses are given per dimension their number must \"\n          \"match number of dimensions defined by lattice sizes. 
\"\n          \"l1: %s, lattice sizes: %s\" %\n          (regularization_info, regularization_amount, lattice_sizes))\n\n  if interpolation not in [\"hypercube\", \"simplex\"]:\n    raise ValueError(\"Lattice interpolation type should be either 'simplex' \"\n                     \"or 'hypercube': %s\" % interpolation)\n\n\n# TODO: investigate whether eps should be bigger.\ndef assert_constraints(weights,\n                       lattice_sizes,\n                       monotonicities,\n                       edgeworth_trusts,\n                       trapezoid_trusts,\n                       monotonic_dominances,\n                       range_dominances,\n                       joint_monotonicities,\n                       joint_unimodalities,\n                       output_min=None,\n                       output_max=None,\n                       eps=1e-6):\n  \"\"\"Asserts that weights satisfy constraints.\n\n  Args:\n    weights: `Lattice` weights tensor of shape: `(prod(lattice_sizes), units)`.\n    lattice_sizes: List or tuple of integers which represents lattice sizes.\n    monotonicities: Monotonicity constraints.\n    edgeworth_trusts: Edgeworth trust constraints.\n    trapezoid_trusts: Trapezoid trust constraints.\n    monotonic_dominances: Monotonic dominance constraints.\n    range_dominances: Range dominance constraints.\n    joint_monotonicities: Joint monotonicity constraints.\n    joint_unimodalities: Joint unimodality constraints.\n    output_min: None or lower bound constraints.\n    output_max: None or upper bound constraints.\n    eps: Allowed constraints violation.\n\n  Returns:\n    List of assertion ops in graph mode or directly executes assertions in eager\n    mode.\n  \"\"\"\n  # TODO: actually assert them.\n  del joint_unimodalities\n\n  if weights.shape[1] > 1:\n    lattice_sizes = lattice_sizes + [int(weights.shape[1])]\n    if monotonicities:\n      monotonicities = monotonicities + [0]\n  weights = tf.reshape(weights, shape=lattice_sizes)\n  asserts = []\n\n  for i in range(len(monotonicities or [])):\n    if monotonicities[i] != 1:\n      continue\n    weights_layers = tf.unstack(weights, axis=i)\n\n    for j in range(1, len(weights_layers)):\n      diff = tf.reduce_min(weights_layers[j] - weights_layers[j - 1])\n      asserts.append(\n          tf.Assert(\n              diff >= -eps,\n              data=[\n                  \"Monotonicity violation\", \"Feature index:\", i,\n                  \"Min monotonicity diff:\", diff, \"Upper layer number:\", j,\n                  \"Epsilon:\", eps, \"Layers:\", weights_layers[j],\n                  weights_layers[j - 1]\n              ]))\n\n  for main_dim, cond_dim, cond_direction in edgeworth_trusts or []:\n    weights_layers = _unstack_nd(weights, [main_dim, cond_dim])\n    for i in range(lattice_sizes[main_dim] - 1):\n      for j in range(lattice_sizes[cond_dim] - 1):\n        diff = tf.reduce_min(\n            cond_direction *\n            ((weights_layers[i + 1][j + 1] - weights_layers[i][j + 1]) -\n             (weights_layers[i + 1][j] - weights_layers[i][j])))\n        asserts.append(\n            tf.Assert(\n                diff >= -eps,\n                data=[\n                    \"Edgeworth trust violation\", \"Feature indices:\", main_dim,\n                    \",\", cond_dim, \"Min trust diff:\", diff, \"Epsilon:\", eps,\n                    \"Layers:\", weights_layers[i + 1][j + 1],\n                    weights_layers[i][j + 1], weights_layers[i + 1][j],\n                    weights_layers[i][j]\n        
        ]))\n\n  for main_dim, cond_dim, cond_direction in trapezoid_trusts or []:\n    weights_layers = _unstack_nd(weights, [main_dim, cond_dim])\n    max_main_dim = lattice_sizes[main_dim] - 1\n    for j in range(lattice_sizes[cond_dim] - 1):\n      lhs_diff = tf.reduce_min(\n          cond_direction * (weights_layers[0][j] - weights_layers[0][j + 1]))\n      asserts.append(\n          tf.Assert(\n              lhs_diff >= -eps,\n              data=[\n                  \"Trapezoid trust violation\", \"Feature indices:\", main_dim,\n                  \",\", cond_dim, \"Min trust diff:\", lhs_diff, \"Epsilon:\", eps,\n                  \"Layers:\", weights_layers[0][j], weights_layers[0][j + 1]\n              ]))\n      rhs_diff = tf.reduce_min(cond_direction *\n                               (weights_layers[max_main_dim][j + 1] -\n                                weights_layers[max_main_dim][j]))\n      asserts.append(\n          tf.Assert(\n              rhs_diff >= -eps,\n              data=[\n                  \"Trapezoid trust violation\", \"Feature indices:\", main_dim,\n                  \",\", cond_dim, \"Min trust diff:\", rhs_diff, \"Epsilon:\", eps,\n                  \"Layers:\", weights_layers[max_main_dim][j + 1],\n                  weights_layers[max_main_dim][j]\n              ]))\n\n  for dominant_dim, weak_dim in monotonic_dominances or []:\n    weights_layers = _unstack_nd(weights, [dominant_dim, weak_dim])\n    for i in range(lattice_sizes[dominant_dim] - 1):\n      for j in range(lattice_sizes[weak_dim] - 1):\n        midpoint = (weights_layers[i + 1][j + 1] + weights_layers[i][j]) / 2\n        dominant_diff = tf.reduce_min(weights_layers[i + 1][j] - midpoint)\n        asserts.append(\n            tf.Assert(\n                dominant_diff >= -eps,\n                data=[\n                    \"Dominance violation\", \"Feature indices:\", dominant_dim,\n                    \",\", weak_dim, \"Min dominance diff:\", dominant_diff,\n                    \"Epsilon:\", eps, \"Layers:\", weights_layers[i][j],\n                    weights_layers[i + 1][j], weights_layers[i + 1][j + 1]\n                ]))\n        weak_diff = tf.reduce_min(midpoint - weights_layers[i][j + 1])\n        asserts.append(\n            tf.Assert(\n                weak_diff >= -eps,\n                data=[\n                    \"Dominance violation\", \"Feature indices:\", dominant_dim,\n                    \",\", weak_dim, \"Min dominance diff:\", weak_diff, \"Epsilon:\",\n                    eps, \"Layers:\", weights_layers[i][j],\n                    weights_layers[i + 1][j], weights_layers[i + 1][j + 1]\n                ]))\n\n  for dominant_dim, weak_dim in range_dominances or []:\n    weights_layers = _unstack_nd(weights, [dominant_dim, weak_dim])\n    dom_dim_size = lattice_sizes[dominant_dim]\n    weak_dim_size = lattice_sizes[weak_dim]\n    for i in range(dom_dim_size):\n      for j in range(weak_dim_size):\n        diff = tf.reduce_min(\n            (weights_layers[dom_dim_size - 1][j] - weights_layers[0][j]) -\n            (weights_layers[i][weak_dim_size - 1] - weights_layers[i][0]))\n        asserts.append(\n            tf.Assert(\n                diff >= -eps,\n                data=[\n                    \"Range dominance violation\", \"Feature indices:\",\n                    dominant_dim, \",\", weak_dim, \"Min dominance diff:\", diff,\n                    \"Epsilon:\", eps, \"Layers:\",\n                    weights_layers[dom_dim_size - 1][j], weights_layers[0][j],\n                   
 weights_layers[i][weak_dim_size - 1], weights_layers[i][0]\n                ]))\n\n  for dim1, dim2 in joint_monotonicities or []:\n    weights_layers = _unstack_nd(weights, [dim1, dim2])\n    for i in range(lattice_sizes[dim1] - 1):\n      for j in range(lattice_sizes[dim2] - 1):\n        midpoint = (weights_layers[i + 1][j] + weights_layers[i][j + 1]) / 2\n        lower_triangle_diff = tf.reduce_min(weights_layers[i + 1][j + 1] -\n                                            midpoint)\n        asserts.append(\n            tf.Assert(\n                lower_triangle_diff >= -eps,\n                data=[\n                    \"Joint monotonicity violation\", \"Feature indices:\", dim1,\n                    \",\", dim2, \"Min lower triangle diff:\", lower_triangle_diff,\n                    \"Epsilon:\", eps, \"Layers:\", weights_layers[i + 1][j + 1],\n                    weights_layers[i + 1][j], weights_layers[i][j + 1]\n                ]))\n        upper_triangle_diff = tf.reduce_min(midpoint - weights_layers[i][j])\n        asserts.append(\n            tf.Assert(\n                upper_triangle_diff >= -eps,\n                data=[\n                    \"Joint monotonicity violation\", \"Feature indices:\", dim1,\n                    \",\", dim2, \"Min upper triangle diff:\", upper_triangle_diff,\n                    \"Epsilon:\", eps, \"Layers:\", weights_layers[i][j],\n                    weights_layers[i + 1][j], weights_layers[i][j + 1]\n                ]))\n\n  if output_min is not None:\n    min_weight = tf.reduce_min(weights)\n    asserts.append(\n        tf.Assert(\n            min_weight >= output_min - eps,\n            data=[\n                \"Lower bound violation.\", \"output_min:\", output_min,\n                \"Smallest weight:\", min_weight, \"Epsilon:\", eps, \"Weights:\",\n                weights\n            ]))\n\n  if output_max is not None:\n    max_weight = tf.reduce_max(weights)\n    asserts.append(\n        tf.Assert(\n            max_weight <= output_max + eps,\n            data=[\n                \"Upper bound violation.\", \"output_max:\", output_max,\n                \"Largest weight:\", max_weight, \"Epsilon:\", eps, \"Weights:\",\n                weights\n            ]))\n  return asserts\n\n\ndef _unstack_nested_lists(tensor_or_list, axis):\n  \"\"\"Unstacks tensors stored within nested list.\"\"\"\n  if isinstance(tensor_or_list, list):\n    return [_unstack_nested_lists(item, axis) for item in tensor_or_list]\n  else:\n    return tf.unstack(tensor_or_list, axis=axis)\n\n\ndef _unstack_nd(tensor, dims):\n  \"\"\"Returns nested lists of tensors resulting from n unstack operations.\"\"\"\n  dims = list(dims)\n  # Following unstack operations will remove some dims. It will result in dims\n  # higher than removed dims shifting left. 
So update passed in dims to reflect\n  # shift resulting from tf.unstack() in advance.\n  for i in range(len(dims) - 1, 0, -1):\n    dims[i] -= sum([dims[i] > previous_dims for previous_dims in dims[:i]])\n\n  result = tensor\n  for dim in dims:\n    result = _unstack_nested_lists(result, axis=dim)\n  return result\n\n\ndef _stack_nested_lists(tensor_or_list, axis):\n  \"\"\"Stacks tensors stored within nested list.\"\"\"\n  if isinstance(tensor_or_list[0], list):\n    return [_stack_nested_lists(item, axis) for item in tensor_or_list]\n  else:\n    return tf.stack(tensor_or_list, axis=axis)\n\n\ndef _stack_nd(tensor, dims):\n  \"\"\"Returns tensor that re-stacks tensor layers formed from unstacking.\"\"\"\n  dims = list(dims)\n  # Following stack operations will add some dims. It will result in dims higher\n  # than added dims shifting right. So update passed in dims to reflect\n  # shift resulting from tf.stack() in advance.\n  for i in range(len(dims) - 1, 0, -1):\n    dims[i] -= sum([dims[i] > previous_dims for previous_dims in dims[:i]])\n\n  result = tensor\n  for dim in reversed(dims):\n    result = _stack_nested_lists(result, axis=dim)\n  return result\n\n\ndef _get_element(lists, indices):\n  \"\"\"Gets element from nested lists of arbitrary depth.\"\"\"\n  result = lists\n  for i in indices:\n    result = result[i]\n  return result\n\n\ndef _set_element(lists, indices, value):\n  \"\"\"Sets element into nested lists of arbitrary depth.\"\"\"\n  result = lists\n  for i in indices[:-1]:\n    result = result[i]\n  result[indices[-1]] = value\n\n\ndef _reverse_second_list_dimension(layers):\n  \"\"\"Reverses each list within a list of lists, but not the outer list.\"\"\"\n  return [layer[::-1] for layer in layers]\n"
  },
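  {
    "path": "examples/lattice_assert_constraints_sketch.py",
    "content": "# Illustrative sketch only: this file is not part of the original library\n# sources. It shows, using the APIs exercised in lattice_test.py, how a small\n# monotonic Lattice layer can be constructed and how its assert_constraints()\n# method (backed by assert_constraints in lattice_lib) can be used to check\n# the layer weights. The file name and all values here are assumptions made\n# for exposition.\nimport tensorflow as tf\n\nfrom tensorflow_lattice.python import lattice_layer as ll\n\nlattice_sizes = [2, 2]\ninput_shape = (len(lattice_sizes),)\n\n# A 2x2 lattice, monotonically increasing in both inputs, with outputs\n# bounded to [0, 1].\nlattice = ll.Lattice(\n    lattice_sizes=lattice_sizes,\n    units=1,\n    monotonicities=[1, 1],\n    output_min=0.0,\n    output_max=1.0,\n    input_shape=input_shape,\n    dtype=tf.float32)\nlattice.build(input_shape)\n\n# In eager mode the assertions execute immediately; in graph mode a list of\n# assertion ops is returned and must be run in a session.\nlattice.assert_constraints(eps=1e-6)\n\n# Evaluate the freshly initialized lattice at one corner of the unit square.\nprint(lattice(tf.constant([[0.0, 1.0]])))\n"
  },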
  {
    "path": "tensorflow_lattice/python/lattice_test.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for Lattice Layer.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\n\nfrom absl import logging\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_lattice.python import lattice_layer as ll\nfrom tensorflow_lattice.python import test_utils\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n\nclass LatticeTest(parameterized.TestCase, tf.test.TestCase):\n\n  def setUp(self):\n    super(LatticeTest, self).setUp()\n    self.disable_all = False\n    self.disable_ensembles = False\n    self.loss_eps = 0.0001\n    self.small_eps = 1e-6\n    keras.utils.set_random_seed(42)\n\n  def _ResetAllBackends(self):\n    keras.backend.clear_session()\n    tf.compat.v1.reset_default_graph()\n\n  def _ScatterXUniformly(self, num_points, lattice_sizes):\n    \"\"\"Deterministically generates num_point random points within lattice.\"\"\"\n    np.random.seed(41)\n    x = []\n    for _ in range(num_points):\n      point = [\n          np.random.random() * (num_vertices - 1.0)\n          for num_vertices in lattice_sizes\n      ]\n      x.append(np.asarray(point))\n    if len(lattice_sizes) == 1:\n      x.sort()\n    return x\n\n  def _ScatterXUniformlyExtendedRange(self, num_points, lattice_sizes):\n    \"\"\"Extends every dimension by 1.0 on both sides and generates points.\"\"\"\n    np.random.seed(41)\n    x = []\n    for _ in range(num_points):\n      point = [\n          np.random.random() * (num_vertices + 1.0) - 1.0\n          for num_vertices in lattice_sizes\n      ]\n      x.append(np.asarray(point))\n    if len(lattice_sizes) == 1:\n      x.sort()\n    return x\n\n  def _SameValueForAllDims(self, num_points, lattice_sizes):\n    \"\"\"Generates random point with same value for every dimension.\"\"\"\n    if lattice_sizes.count(lattice_sizes[0]) != len(lattice_sizes):\n      raise ValueError(\"All dimensions must be of same size. \"\n                       \"They are: {}\".format(lattice_sizes))\n    np.random.seed(41)\n    x = []\n    for _ in range(num_points):\n      rand = np.random.random() * (lattice_sizes[0] - 1.0)\n      point = [rand] * len(lattice_sizes)\n      x.append(np.asarray(point))\n    if len(lattice_sizes) == 1:\n      x.sort()\n    return x\n\n  def _TwoDMeshGrid(self, num_points, lattice_sizes):\n    \"\"\"Mesh grid for visualisation of 3-d surfaces via pyplot.\"\"\"\n    if len(lattice_sizes) != 2:\n      raise ValueError(\"2-d mesh grid is possible only for 2-d lattice. 
Lattice\"\n                       \" sizes given: %s\" % lattice_sizes)\n    return test_utils.two_dim_mesh_grid(\n        num_points=num_points,\n        x_min=0.0,\n        y_min=0.0,\n        x_max=lattice_sizes[0] - 1.0,\n        y_max=lattice_sizes[1] - 1.0)\n\n  def _TwoDMeshGridExtendedRange(self, num_points, lattice_sizes):\n    \"\"\"Mesh grid extended by 1.0 on every side.\"\"\"\n    if len(lattice_sizes) != 2:\n      raise ValueError(\"2-d mesh grid is possible only for 2-d lattice. Lattice\"\n                       \" sizes given: %s\" % lattice_sizes)\n    return test_utils.two_dim_mesh_grid(\n        num_points=num_points,\n        x_min=-1.0,\n        y_min=-1.0,\n        x_max=lattice_sizes[0],\n        y_max=lattice_sizes[1])\n\n  def _Sin(self, x):\n    return math.sin(x[0])\n\n  def _SinPlusX(self, x):\n    return math.sin(x[0]) + x[0] / 3.0\n\n  def _SinPlusLargeX(self, x):\n    return math.sin(x[0]) + x[0]\n\n  def _SinPlusXNd(self, x):\n    return np.sum([math.sin(y) + y / 5.0 for y in x])\n\n  def _SinOfSum(self, x):\n    return math.sin(sum(x))\n\n  def _Square(self, x):\n    return x[0]**2\n\n  def _Max(self, x):\n    return np.amax(x)\n\n  def _WeightedSum(self, x):\n    result = 0.0\n    for i in range(len(x)):\n      result += (i + 1.0) * x[i]\n    return result\n\n  def _MixedSignWeightedSum(self, x):\n    result = 0.0\n    for i in range(len(x)):\n      sign = (i % 2) * -2 + 1\n      result += sign * (i + 1.0) * x[i]\n    return result\n\n  def _PseudoLinear(self, x):\n    result = 0.0\n    for i in range(len(x)):\n      result += 2 * x[i]\n      for j in range(len(x)):\n        if i != j:\n          result += x[i] * x[j]\n    return result\n\n  def _ScaledSum(self, x):\n    result = 0.0\n    for y in x:\n      result += y / len(x)\n    return result\n\n  def _GetMultiOutputInitializer(self, weights):\n    \"\"\"Tiles given weights along 'units' dimension.\"\"\"\n\n    def Initializer(shape, dtype):\n      return tf.tile(\n          tf.constant(weights, shape=[len(weights), 1], dtype=dtype),\n          multiples=[1, shape[1]])\n\n    return Initializer\n\n  def _GetTrainingInputsAndLabels(self, config):\n    \"\"\"Generates training inputs and labels.\n\n    Args:\n      config: Dictionary with config for this unit test.\n\n    Returns:\n      Tuple `(training_inputs, training_labels)` where\n        `training_inputs` and `training_labels` are data for training.\n    \"\"\"\n    raw_training_inputs = config[\"x_generator\"](\n        num_points=config[\"num_training_records\"],\n        lattice_sizes=config[\"lattice_sizes\"])\n\n    if isinstance(raw_training_inputs, tuple):\n      # This means that raw inputs are 2-d mesh grid. 
Convert them into list of\n      # 2-d points.\n      training_inputs = list(np.dstack(raw_training_inputs).reshape((-1, 2)))\n    else:\n      training_inputs = raw_training_inputs\n\n    training_labels = [config[\"y_function\"](x) for x in training_inputs]\n    return training_inputs, training_labels\n\n  def _SetDefaults(self, config):\n    config.setdefault(\"monotonicities\", None)\n    config.setdefault(\"unimodalities\", None)\n    config.setdefault(\"edgeworth_trusts\", None)\n    config.setdefault(\"trapezoid_trusts\", None)\n    config.setdefault(\"monotonic_dominances\", None)\n    config.setdefault(\"range_dominances\", None)\n    config.setdefault(\"joint_monotonicities\", None)\n    config.setdefault(\"joint_unimodalities\", None)\n    config.setdefault(\"output_min\", None)\n    config.setdefault(\"output_max\", None)\n    config.setdefault(\"signal_name\", \"TEST\")\n    config.setdefault(\"kernel_initializer\", \"linear_initializer\")\n    config.setdefault(\"num_projection_iterations\", 10)\n    config.setdefault(\"monotonic_at_every_step\", True)\n    config.setdefault(\"target_monotonicity_diff\", 0.0)\n    config.setdefault(\"kernel_regularizer\", None)\n    config.setdefault(\"units\", 1)\n    config.setdefault(\"lattice_index\", 0)\n    config.setdefault(\"interpolation\", \"hypercube\")\n\n    return config\n\n  def _TestEnsemble(self, config):\n    \"\"\"Verifies that 'units > 1' lattice produces same output as 'units==1'.\"\"\"\n    # Note that the initialization of the lattice must be the same across the\n    # units dimension (otherwise the loss will be different).\n    if self.disable_ensembles:\n      return\n    config = dict(config)\n    config[\"num_training_epoch\"] = 3\n    losses = []\n    for units, lattice_index in [(1, 0), (3, 0), (3, 2)]:\n      config[\"units\"] = units\n      config[\"lattice_index\"] = lattice_index\n      losses.append(self._TrainModel(config))\n    self.assertAlmostEqual(min(losses), max(losses), delta=self.loss_eps)\n\n  def _TrainModel(self, config):\n    logging.info(\"Testing config:\")\n    logging.info(config)\n    config = self._SetDefaults(config)\n    self._ResetAllBackends()\n\n    training_inputs, training_labels = (\n        self._GetTrainingInputsAndLabels(config))\n\n    units = config[\"units\"]\n    lattice_sizes = config[\"lattice_sizes\"]\n    if units > 1:\n      # In order to test multi 'units' lattice replicate inputs 'units' times\n      # and later use just one out of 'units' outputs in order to ensure that\n      # multi 'units' lattice trains exactly similar to single 'units' one.\n      training_inputs = [\n          np.tile(np.expand_dims(x, axis=0), reps=[units, 1])\n          for x in training_inputs\n      ]\n      input_shape = (units, len(lattice_sizes))\n    else:\n      input_shape = (len(lattice_sizes),)\n\n    keras_layer = ll.Lattice(\n        lattice_sizes=lattice_sizes,\n        units=units,\n        monotonicities=config[\"monotonicities\"],\n        unimodalities=config[\"unimodalities\"],\n        edgeworth_trusts=config[\"edgeworth_trusts\"],\n        trapezoid_trusts=config[\"trapezoid_trusts\"],\n        monotonic_dominances=config[\"monotonic_dominances\"],\n        range_dominances=config[\"range_dominances\"],\n        joint_monotonicities=config[\"joint_monotonicities\"],\n        joint_unimodalities=config[\"joint_unimodalities\"],\n        output_min=config[\"output_min\"],\n        output_max=config[\"output_max\"],\n        
num_projection_iterations=config[\"num_projection_iterations\"],\n        monotonic_at_every_step=config[\"monotonic_at_every_step\"],\n        interpolation=config[\"interpolation\"],\n        kernel_initializer=config[\"kernel_initializer\"],\n        kernel_regularizer=config[\"kernel_regularizer\"],\n        input_shape=input_shape,\n        dtype=tf.float32)\n    model = keras.models.Sequential()\n    model.add(keras_layer)\n\n    # When we use multi-unit lattices, we only extract a single lattice for\n    # testing.\n    if units > 1:\n      lattice_index = config[\"lattice_index\"]\n      model.add(\n          keras.layers.Lambda(lambda x: x[:, lattice_index:lattice_index + 1]))\n\n    optimizer = config[\"optimizer\"](learning_rate=config[\"learning_rate\"])\n    model.compile(loss=keras.losses.mean_squared_error, optimizer=optimizer)\n\n    training_data = (training_inputs, training_labels)\n    loss = test_utils.run_training_loop(\n        config=config, training_data=training_data, keras_model=model\n    )\n\n    if tf.executing_eagerly():\n      tf.print(\"final weights: \", keras_layer.kernel)\n    assertion_ops = keras_layer.assert_constraints(\n        eps=-config[\"target_monotonicity_diff\"])\n    if not tf.executing_eagerly() and assertion_ops:\n      tf.compat.v1.keras.backend.get_session().run(assertion_ops)\n\n    return loss\n\n  def testMonotonicityOneD(self):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [20],\n        \"num_training_records\": 128,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinPlusX,\n        \"monotonicities\": [1],\n        \"output_min\": 0.0,\n        \"output_max\": 7.0,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.110467, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [20],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": lambda x: -self._SinPlusX(x),\n        \"monotonicities\": [\"increasing\"],\n        \"output_min\": -7.0,\n        \"output_max\": 0.0,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 2.889168, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [5],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinPlusLargeX,\n        \"monotonicities\": [1],\n        \"output_min\": 0.0,\n        \"output_max\": 6.0,\n        # Target function is strictly increasing.\n        \"target_monotonicity_diff\": 0.02,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000781, delta=self.loss_eps)\n\n  def testMonotonicityTwoD(self):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [21, 6],\n        \"num_training_records\": 900,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        
\"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._SinPlusXNd,\n        \"monotonicities\": [1, 1],\n        \"output_min\": 0.0,\n        \"output_max\": 7.0,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.443284, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [6, 21],\n        \"num_training_records\": 900,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._SinPlusXNd,\n        \"monotonicities\": [1, 1],\n        \"output_min\": 0.0,\n        \"output_max\": 7.0,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.443284, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [6, 21],\n        \"num_training_records\": 900,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._SinPlusXNd,\n        \"monotonicities\": [\"none\", \"increasing\"],\n        \"output_min\": 0.0,\n        \"output_max\": 7.0,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.202527, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [6, 21],\n        \"num_training_records\": 900,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._SinPlusXNd,\n        \"monotonicities\": [1, 0],\n        \"output_min\": 0.0,\n        \"output_max\": 7.0,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.244739, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": lambda x: -self._ScaledSum(x),\n        \"monotonicities\": [1, 1],\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.051462, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testMonotonicity5d(self):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [2, 2, 2, 2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._ScaledSum,\n        \"monotonicities\": [1, 1, 1, 1, 1],\n        \"kernel_initializer\": keras.initializers.Constant(value=0.5),\n        # Function is strictly increasing everywhere, so request monotonicity\n        # diff to be strictly positive.\n        \"target_monotonicity_diff\": 0.08,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000002, delta=self.loss_eps)\n\n    config = {\n        \"lattice_sizes\": [2, 2, 2, 2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 40,\n        
\"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": lambda x: -self._ScaledSum(x),\n        \"monotonicities\": [1, 1, 1, 1, 1],\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.014971, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [3, 3, 3, 3],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinOfSum,\n        \"monotonicities\": [1, \"increasing\", 1, 1],\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.358079, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([0, 1, 1],),\n      ([1, 0, 1],),\n      ([1, 1, 0],),\n  )\n  def testMonotonicityEquivalence(self, monotonicities):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [3, 3, 3],\n        \"monotonicities\": monotonicities,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": self._SameValueForAllDims,\n        \"y_function\": self._SinOfSum,\n        \"kernel_initializer\": \"zeros\",\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000286, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testMonotonicity10dAlmostMonotone(self):\n    if self.disable_all:\n      return\n    np.random.seed(4411)\n    num_weights = 1024\n    weights = [1.0 * i / num_weights for i in range(num_weights)]\n    for _ in range(10):\n      i = int(np.random.random() * num_weights)\n      weights[i] = 0.0\n\n    config = {\n        \"lattice_sizes\": [2] * 10,\n        \"num_training_records\": 1000,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 100.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": test_utils.get_hypercube_interpolation_fn(weights),\n        \"monotonicities\": [1] * 10,\n        \"kernel_initializer\": \"zeros\",\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000027, delta=self.loss_eps)\n\n    config[\"monotonicities\"] = [0, 1, 0, 1, 1, 0, 1, 1, 1, 0]\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000019, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testMonotonicity10dSinOfSum(self):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [2] * 10,\n        \"num_training_records\": 1000,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 100.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinOfSum,\n        \"monotonicities\": [1] * 10,\n        \"output_min\": -1.0,\n        \"output_max\": 1.0,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.089950, delta=self.loss_eps)\n\n    config[\"monotonicities\"] = [0, 1, 0, 1, 1, 0, 1, 1, 1, 0]\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.078830, 
delta=self.loss_eps)\n\n    config[\"monotonicities\"] = [0, 0, 0, 1, 0, 1, 0, 0, 0, 0]\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.052190, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([(0, 1, 1)], [], 0.025785),\n      (None, [(0, 1, 1)], 0.042566),\n      ([(0, 1, \"positive\")], [(0, 1, \"positive\")], 0.042566),\n  )\n  def testSimpleTrustTwoD(self, edgeworth_trusts, trapezoid_trusts,\n                          expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._Max,\n        \"monotonicities\": [1, 0],\n        \"edgeworth_trusts\": edgeworth_trusts,\n        \"trapezoid_trusts\": trapezoid_trusts,\n        \"output_min\": 0.0,\n        \"output_max\": 1.0,\n        # Leave margin of error (floating point) for trust projection.\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([(1, 0, -1)], None, 3.23711),\n      (None, [(1, 0, -1)], 6.663453),\n      ([(1, 0, \"negative\")], [(1, 0, \"negative\")], 9.846122),\n  )\n  def testDenseTrustTwoD(self, edgeworth_trusts, trapezoid_trusts,\n                         expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [4, 3],\n        \"num_training_records\": 150,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._PseudoLinear,\n        \"monotonicities\": [0, 1],\n        \"edgeworth_trusts\": edgeworth_trusts,\n        \"trapezoid_trusts\": trapezoid_trusts,\n        \"output_min\": 0.0,\n        \"output_max\": 22.0,\n        # Leave margin of error (floating point) for trust projection.\n        \"target_monotonicity_diff\": -1e-5,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    if not edgeworth_trusts or not trapezoid_trusts:\n      self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([(0, 1, 1)], None, 0.010525),\n      (None, [(0, 1, 1)], 0.013343),\n      ([(0, 1, 1)], [(0, 1, 1)], 0.013343),\n  )\n  def testSimpleTrust4D(self, edgeworth_trusts, trapezoid_trusts,\n                        expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [2, 2, 2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._Max,\n        \"monotonicities\": [1, 0, 1, 1],\n        \"edgeworth_trusts\": edgeworth_trusts,\n        \"trapezoid_trusts\": trapezoid_trusts,\n        \"output_min\": 0.0,\n        \"output_max\": 1.0,\n        # Leave margin of error (floating point) for trust projection.\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 
expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([(0, 1, 1), (3, 1, -1), (3, 2, 1)], None, 0.334325),\n      (None, [(0, 1, 1), (3, 1, -1), (3, 2, 1)], 0.387444),\n      ([(0, 1, 1), (3, 1, -1)], [(3, 1, -1), (3, 2, 1)], 0.381514),\n  )\n  def testMultiDenseTrust4D(self, edgeworth_trusts, trapezoid_trusts,\n                            expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [3, 3, 3, 3],\n        \"num_training_records\": 1000,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinOfSum,\n        \"monotonicities\": [1, 0, 0, 1],\n        \"edgeworth_trusts\": edgeworth_trusts,\n        \"trapezoid_trusts\": trapezoid_trusts,\n        \"output_min\": -0.5,\n        \"output_max\": 0.9,\n        # Leave margin of error (floating point) for trust projection.\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    if not edgeworth_trusts or not trapezoid_trusts:\n      self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([(0, 1, 1)],),\n      ([(1, 2, 1)],),\n      ([(2, 0, 1)],),\n  )\n  def testEdgeworthTrustEquivalence(self, edgeworth_trusts):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [3, 3, 3],\n        \"monotonicities\": [1, 1, 1],\n        \"edgeworth_trusts\": edgeworth_trusts,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": self._SameValueForAllDims,\n        \"y_function\": self._PseudoLinear,\n        \"kernel_initializer\": \"zeros\",\n        # Leave margin of error (floating point) for trust projection.\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.006912, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      (None, 0.00000),\n      ([(1, 0)], 0.00000),\n      ([(0, 1)], 0.05092),\n  )\n  def testSimpleMonotonicDominance2D(self, monotonic_dominances, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._WeightedSum,\n        \"monotonicities\": [1, 1],\n        \"monotonic_dominances\": monotonic_dominances,\n        \"output_min\": 0.0,\n        \"output_max\": 3.0,\n        # Leave margin of error (floating point) for dominance projection.\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      (None, 0.00113),\n      ([(1, 0)], 0.00113),\n      ([(0, 1)], 0.81520),\n  )\n  def testDenseMonotonicDominance2D(self, monotonic_dominances, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [5, 5],\n        
\"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"num_projection_iterations\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._WeightedSum,\n        \"monotonicities\": [1, 1],\n        \"monotonic_dominances\": monotonic_dominances,\n        \"output_min\": 0.0,\n        \"output_max\": 12.0,\n        # Leave margin of error (floating point) for dominance projection.\n        \"target_monotonicity_diff\": -1e-2,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([(1, 0), (2, 1)], 2.52985),\n      ([(0, 1), (1, 2)], 6.16700),\n  )\n  def testDenseMonotonicDominance5D(self, monotonic_dominances, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [5, 5, 5, 5, 5],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 300,\n        \"num_projection_iterations\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._WeightedSum,\n        \"monotonicities\": [1, 1, 1, 1, 1],\n        \"monotonic_dominances\": monotonic_dominances,\n        \"output_min\": 0.0,\n        \"output_max\": 60.0,\n        # Leave margin of error (floating point) for dominance projection.\n        \"target_monotonicity_diff\": -1e-1,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      (None, 0.00618),\n      ([(1, 0)], 0.00618),\n      ([(0, 1)], 0.05092),\n  )\n  def testSimpleRangeDominance2D(self, range_dominances, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.1,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._WeightedSum,\n        \"monotonicities\": [1, 1],\n        \"range_dominances\": range_dominances,\n        \"output_min\": 0.0,\n        \"output_max\": 3.0,\n        # Leave margin of error (floating point) for dominance projection.\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      (None, 0.24449, 1),\n      ([(1, 0)], 0.24449, 2),\n      ([(0, 1)], 0.61649, 3),\n  )\n  def testDenseRangeDominance2D(self, range_dominances, expected_loss, expid):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [5, 5],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 40,\n        \"num_projection_iterations\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.1,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._WeightedSum,\n        \"monotonicities\": [1, 1],\n        \"range_dominances\": range_dominances,\n        \"output_min\": 0.0,\n        \"output_max\": 12.0,\n        # Leave margin of error (floating 
point) for dominance projection.\n        \"target_monotonicity_diff\": -1e-2,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([(1, 0), (2, 1)], 1.24238),\n      ([(0, 1), (1, 2)], 2.14021),\n  )\n  def testDenseRangeDominance5D(self, range_dominances, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [5, 5, 5, 5, 5],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 300,\n        \"num_projection_iterations\": 40,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._WeightedSum,\n        \"monotonicities\": [1, 1, 1, 1, 1],\n        \"range_dominances\": range_dominances,\n        \"output_min\": 0.0,\n        \"output_max\": 60.0,\n        # Leave margin of error (floating point) for dominance projection.\n        \"target_monotonicity_diff\": -1e-1,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=0.01)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      (None, 0.00000),\n      ([(0, 1)], 0.05092),\n      ([(1, 0)], 0.05092),\n  )\n  def testSimpleJointMonotonicity2D(self, joint_monotonicities, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._MixedSignWeightedSum,\n        \"monotonicities\": [0, 0],\n        \"joint_monotonicities\": joint_monotonicities,\n        \"output_min\": -2.0,\n        \"output_max\": 1.0,\n        # Leave margin of error (floating point) for dominance projection.\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      (None, 0.001765),\n      (([0], \"valley\"), 0.306134),\n      (((0,), \"peak\"), 0.306134),\n  )\n  def testJointUnimodality1D(self, joint_unimodalities, expected_loss):\n    if self.disable_all:\n      return\n\n    def _Sin(x):\n      result = math.sin(x[0])\n      # Make test exactly symmetric for both unimodality directions.\n      if joint_unimodalities and joint_unimodalities[-1] == \"peak\":\n        result *= -1\n      return result\n\n    config = {\n        \"lattice_sizes\": [15],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": _Sin,\n        \"monotonicities\": [0],\n        \"joint_unimodalities\": joint_unimodalities,\n        \"output_min\": -1.0,\n        \"output_max\": 1.0,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testJointUnimodality2DSinOfSum(self):\n    # This test demonstrates difference of joint unimodaity vs independently\n    # unimofal dims. 
For latter loss would be 0.225369\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [3, 3],\n        \"num_training_records\": 36*9,\n        \"num_training_epoch\": 150,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.1,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": lambda x: -math.sin(sum(x) * 2.0),\n        \"monotonicities\": [0, 0],\n        \"joint_unimodalities\": ([0, 1], \"peak\"),\n        \"output_min\": -1.0,\n        \"output_max\": 1.0,\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.136693, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      (None, 0.036196),\n      ([([0], \"valley\")], 0.221253),\n      ([([1], \"valley\")], 0.221253),\n      ([([0, 1], \"valley\")], 0.280938),\n      ([((1, 0), \"valley\")], 0.280938),\n  )\n  def testJointUnimodality2DWshaped(self, joint_unimodalities, expected_loss):\n    # Test larger lattice.\n    if self.disable_all:\n      return\n\n    center = (3, 3)\n\n    def WShaped2dFunction(x):\n      distance = lambda x1, y1, x2, y2: ((x2 - x1)**2 + (y2 - y1)**2)**0.5\n      d = distance(x[0], x[1], center[0], center[1])\n      t = (d - 0.6 * center[0])**2\n      return min(t, 6.0 - t)\n\n    config = {\n        \"lattice_sizes\": [coordinate * 2 + 1 for coordinate in center],\n        \"num_training_records\": 36 * 9,\n        \"num_training_epoch\": 18,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": WShaped2dFunction,\n        \"monotonicities\": [0, 0],\n        \"joint_unimodalities\": joint_unimodalities,\n        \"output_min\": 0.0,\n        \"output_max\": 3.0,\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      (([0, 1], \"valley\"),),\n      (([1, 0], \"valley\"),),\n      (([0, 2], \"valley\"),),\n      (([0, 3], \"valley\"),),\n      (([3, 0], \"valley\"),),\n      (([1, 2], \"valley\"),),\n      (([1, 3], \"valley\"),),\n      (([3, 1], \"valley\"),),\n      (([2, 3], \"valley\"),),\n  )\n  def testJointUnimodality2OutOf4D(self, joint_unimodalities):\n    # Function is similar to 2dWshaped test. 
Data is generated identically for\n    # all combinations of unimodal pairs so loss should be same for any pair of\n    # dimensions constrained for unimodality.\n    if self.disable_all:\n      return\n\n    center = (2, 2)\n    center_indices = joint_unimodalities[0]\n\n    def WShaped2dFunction(x):\n      distance = lambda x1, y1, x2, y2: ((x2 - x1)**2 + (y2 - y1)**2)**0.5\n      d = distance(x[center_indices[0]], x[center_indices[1]], center[0],\n                   center[1])\n      t = (d - 0.6 * center[0])**2\n      return min(t, 4.5 - t)\n\n    def _DistributeXUniformly(num_points, lattice_sizes):\n      del num_points\n      points_per_vertex = 2\n      result = []\n      for i in range(0, lattice_sizes[0] * points_per_vertex + 1):\n        for j in range(0, lattice_sizes[1] * points_per_vertex + 1):\n          for k in range(0, lattice_sizes[2] * points_per_vertex + 1):\n            for l in range(0, lattice_sizes[3] * points_per_vertex + 1):\n              p = [\n                  i / float(points_per_vertex), j / float(points_per_vertex),\n                  k / float(points_per_vertex), l / float(points_per_vertex)\n              ]\n              result.append(p)\n      return result\n\n    lattice_sizes = [2] * 4\n    for i, center_value in zip(center_indices, center):\n      lattice_sizes[i] = center_value * 2 + 1\n\n    config = {\n        \"lattice_sizes\": lattice_sizes,\n        \"num_training_records\": 1,  # Not used by x_generator for this test.\n        \"num_training_epoch\": 10,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": _DistributeXUniformly,\n        \"y_function\": WShaped2dFunction,\n        \"monotonicities\": None,\n        \"joint_unimodalities\": [joint_unimodalities],\n        \"output_min\": 0.0,\n        \"output_max\": 3.0,\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.845696, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testJointUnimodality3D(self):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [3, 3, 3, 3],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 30,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinOfSum,\n        \"monotonicities\": [0, 0, 0, 0],\n        \"joint_unimodalities\": ([0, 1, 3], \"valley\"),\n        \"output_min\": -1.0,\n        \"output_max\": 1.0,\n        \"target_monotonicity_diff\": -1e-6,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.026094, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      (None, 0.16301),\n      ([(0, 1)], 0.86386),\n      ([(1, 0)], 0.86413),\n  )\n  def testDenseJointMonotonicity2D(self, joint_monotonicities, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [5, 5],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 40,\n        \"num_projection_iterations\": 40,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._MixedSignWeightedSum,\n        \"monotonicities\": [0, 0],\n        \"joint_monotonicities\": joint_monotonicities,\n        
\"output_min\": -8.0,\n        \"output_max\": 4.0,\n        # Leave margin of error (floating point) for dominance projection.\n        \"target_monotonicity_diff\": -1e-2,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([(0, 1)], 36.75898),)\n  def testDenseJointMonotonicity5D(self, joint_monotonicities, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [5, 5, 5, 5, 5],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 100,\n        \"num_projection_iterations\": 40,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._MixedSignWeightedSum,\n        \"monotonicities\": [0, 0, 0, 0, 0],\n        \"joint_monotonicities\": joint_monotonicities,\n        \"output_min\": -24.0,\n        \"output_max\": 36.0,\n        # Leave margin of error (floating point) for dominance projection.\n        \"target_monotonicity_diff\": -1e-1,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      # Custom TFL initializer:\n      (\"linear_initializer\", 0.126068),\n      # Standard Keras initializer:\n      (keras.initializers.Constant(value=1.5), 0.430379),\n      # Standard Keras initializer specified as string constant:\n      (\"zeros\", 1.488072),\n  )\n  def testInitializerType(self, initializer, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [2, 3],\n        \"num_training_records\": 98,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._Max,\n        \"output_min\": 0.0,\n        \"output_max\": 2.0,\n        \"kernel_initializer\": initializer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def _MergeDicts(self, x, y):\n    z = dict(x)\n    z.update(y)\n    return z\n\n  def testLinearMonotonicInitializer(self):\n    if self.disable_all:\n      return\n    # Test initializer by training linear function using 0 iteration and verify\n    # that loss is 0.\n    config = {\n        \"num_training_records\": 96,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n    }  # pyformat: disable\n\n    init_config = {\n        \"lattice_sizes\": [3, 4],\n        \"monotonicities\": [0, 0],\n        \"output_min\": -1.0,\n        \"output_max\": 2.0,\n    }\n    config[\"kernel_initializer\"] = \"LinearInitializer\"\n    config[\"y_function\"] = test_utils.get_linear_lattice_interpolation_fn(\n        **init_config)\n    total_config = self._MergeDicts(config, init_config)\n    loss = self._TrainModel(total_config)\n    self.assertAlmostEqual(loss, 0.0, delta=self.small_eps)\n    self._TestEnsemble(total_config)\n\n    # Change generator since we need more than 2 dimensions from now on.\n    config[\"x_generator\"] = self._ScatterXUniformly\n\n    init_config = {\n        
\"lattice_sizes\": [2, 3, 4, 5],\n        \"monotonicities\": [1, 1, 0, 1],\n        \"output_min\": 12.0,\n        \"output_max\": 22.0,\n    }\n    config[\"kernel_initializer\"] = ll.LinearInitializer(**init_config)\n    config[\"y_function\"] = test_utils.get_linear_lattice_interpolation_fn(\n        **init_config)\n    total_config = self._MergeDicts(config, init_config)\n    loss = self._TrainModel(total_config)\n    self.assertAlmostEqual(loss, 0.0, delta=self.small_eps)\n    self._TestEnsemble(total_config)\n\n    init_config = {\n        \"lattice_sizes\": [2, 3, 4, 5],\n        \"monotonicities\": [0, 1, 0, 1],\n        \"output_min\": -10,\n        \"output_max\": -5,\n    }\n    config[\"kernel_initializer\"] = ll.LinearInitializer(**init_config)\n    config[\"y_function\"] = test_utils.get_linear_lattice_interpolation_fn(\n        **init_config)\n    total_config = self._MergeDicts(config, init_config)\n    loss = self._TrainModel(total_config)\n    self.assertAlmostEqual(loss, 0.0, delta=self.small_eps)\n    self._TestEnsemble(total_config)\n\n    # Try to fit some other function and see loss >0 to ensure that this test\n    # does not always returns 0.\n    config[\"y_function\"] = self._SinOfSum\n    total_config = self._MergeDicts(config, init_config)\n    loss = self._TrainModel(total_config)\n    self.assertGreater(loss, 0.1)\n    self._TestEnsemble(total_config)\n\n    init_config = {\n        \"lattice_sizes\": [2, 3, 4, 5],\n        \"monotonicities\": [0, 0, 0, 0],\n        \"output_min\": 1.0,\n        \"output_max\": 3.0,\n    }\n    config[\"kernel_initializer\"] = \"linear_initializer\"\n    config[\"y_function\"] = test_utils.get_linear_lattice_interpolation_fn(\n        **init_config)\n    total_config = self._MergeDicts(config, init_config)\n    loss = self._TrainModel(total_config)\n    self.assertAlmostEqual(loss, 0.0, delta=self.small_eps)\n    self._TestEnsemble(total_config)\n\n  def testUnimodalInitializer(self):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [3, 4],\n        \"unimodalities\": [1, 1],\n        \"kernel_initializer\": \"linear_initializer\",\n        \"num_training_records\": 96,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._Max,\n        \"output_min\": 0.0,\n        \"output_max\": 2.0,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 1.292362, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config[\"unimodalities\"] = [\"valley\", \"none\"]\n    config[\"monotonicities\"] = [\"none\", \"increasing\"]\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.794330, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config[\"unimodalities\"] = [\"peak\", \"none\"]\n    config[\"monotonicities\"] = [\"none\", \"increasing\"]\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 1.082982, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testRandomMonotonicInitializer(self):\n    if self.disable_all:\n      return\n    lattice_sizes = [2, 2]\n    units = 1\n    monotonicities = [1, 1]\n    output_min = 0.0\n    output_max = 1.0\n    kernel_initializer = ll.RandomMonotonicInitializer(\n        lattice_sizes=lattice_sizes,\n        output_min=output_min,\n        output_max=output_max)\n    input_shape = (len(lattice_sizes),)\n\n    
first_random_lattice = ll.Lattice(\n        lattice_sizes=lattice_sizes,\n        units=units,\n        monotonicities=monotonicities,\n        output_min=output_min,\n        output_max=output_max,\n        kernel_initializer=kernel_initializer,\n        input_shape=input_shape,\n        dtype=tf.float32)\n    first_random_lattice.build(input_shape)\n    first_weights = first_random_lattice.get_weights()\n\n    second_random_lattice = ll.Lattice(\n        lattice_sizes=lattice_sizes,\n        units=units,\n        monotonicities=monotonicities,\n        output_min=output_min,\n        output_max=output_max,\n        kernel_initializer=kernel_initializer,\n        input_shape=input_shape,\n        dtype=tf.float32)\n    second_random_lattice.build(input_shape)\n    second_weights = second_random_lattice.get_weights()\n\n    # Assert Constraints on Lattice\n    first_random_lattice.assert_constraints(eps=1e-6)\n    second_random_lattice.assert_constraints(eps=1e-6)\n    # Assert Weight Bounds And Randomness\n    self.assertAllInRange(first_weights, output_min, output_max)\n    self.assertAllInRange(second_weights, output_min, output_max)\n    self.assertNotAllEqual(first_weights, second_weights)\n\n  def testAssertMonotonicity(self):\n    if self.disable_all:\n      return\n    # Specify non monotonic initializer and do 0 training iterations so no\n    # projections are being executed.\n    config = {\n        \"lattice_sizes\": [2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._ScaledSum,\n        \"monotonicities\": [0, 0],\n        \"kernel_initializer\": self._GetMultiOutputInitializer(\n            weights=[4.0, 3.0, 2.0, 1.0])\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 4.865740, delta=self.loss_eps)\n\n    for monotonicity in [[0, 1], [1, 0], [1, 1]]:\n      for units in [1, 3]:\n        config[\"monotonicities\"] = monotonicity\n        config[\"units\"] = units\n        with self.assertRaises(tf.errors.InvalidArgumentError):\n          self._TrainModel(config)\n\n  def testBounds(self):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [20],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 40,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._Sin,\n        \"output_min\": -0.6,\n        \"output_max\": 0.4,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.109398, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [11, 4],\n        \"num_training_records\": 270,\n        \"num_training_epoch\": 40,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": self._SinPlusXNd,\n        \"monotonicities\": [1, 1],\n        \"output_min\": 1.0,\n        \"output_max\": 2.5,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.380813, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2] * 5,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 
40,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinOfSum,\n        \"monotonicities\": [1, 1, 0, 1, 0],\n        \"output_min\": 0.3,\n        \"output_max\": 0.7,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.145910, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testInputOutOfBounds(self):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [6],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformlyExtendedRange,\n        \"y_function\": self._Sin,\n        \"kernel_initializer\": keras.initializers.Zeros(),\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.018727, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGridExtendedRange,\n        \"y_function\": self._SinOfSum,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.130813, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      # Laplacian with l1 and l2:\n      ((\"laplacian\", 0.005, 0.01), 0.03, 0.021399),\n      # Different regularization amount for every dimension:\n      ((\"laplacian\", [0.005, 0.01], [0.01, 0.02]), 0.045, 0.027941),\n      # Torsion with l1 and l2:\n      ((\"torsion\", 0.1, 0.01), 0.11, 0.06738),\n      # Different regularization amount for every dimension:\n      ((\"torsion\", [2.0, 0.05], [0.1, 0.1]), 0.11, 0.06738),\n      # List of regularizers:\n      ([(\"torsion\", 0.1, 0.0), (\"Torsion\", 0.0, 0.01)], 0.11, 0.06738),\n      # Standard Keras regularizer:\n      (keras.regularizers.l1_l2(l1=0.01, l2=0.1), 0.33, 0.214418),\n  )\n  def testRegularizers2d(self, regularizer, pure_reg_loss, training_loss):\n    if self.disable_all:\n      return\n    weights = [0.0, 1.0, 1.0, 1.0]\n    config = {\n        \"lattice_sizes\": [2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": test_utils.get_hypercube_interpolation_fn(\n            coefficients=weights),\n        \"kernel_initializer\": self._GetMultiOutputInitializer(weights=weights),\n        \"kernel_regularizer\": regularizer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    # This loss is pure regularization loss because initializer matches target\n    # function and there was 0 training epochs.\n    self.assertAlmostEqual(loss, pure_reg_loss, delta=self.loss_eps)\n\n    multioutput_config = dict(config)\n    units = 3\n    multioutput_config[\"units\"] = units\n    loss = self._TrainModel(multioutput_config)\n    self.assertAlmostEqual(loss, pure_reg_loss * units, delta=self.loss_eps)\n\n    config[\"num_training_epoch\"] = 20\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, training_loss, delta=self.loss_eps)\n\n  
@parameterized.parameters(\n      ((\"torsion\", 0.001, 0.0001), 0.147405),\n      ((\"laplacian\", 0.001, 0.0001), 0.193870),\n  )\n  def testRegularizersLargeLattice(self, regularizer, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [3, 4, 3, 4],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinOfSum,\n        \"kernel_regularizer\": regularizer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n\n  def testHighDimensionsStressTest(self):\n    if self.disable_all:\n      return\n    lattice_sizes = [3, 3] + [2] * 14\n    monotonicities = [0] * 16\n    monotonicities[3], monotonicities[4], monotonicities[10] = (1, 1, 1)\n    unimodalities = [0] * 16\n    unimodalities[1] = 1\n    config = {\n        \"lattice_sizes\": lattice_sizes,\n        \"units\": 2,\n        \"monotonicities\": monotonicities,\n        \"unimodalities\": unimodalities,\n        \"edgeworth_trusts\": [(3, 2, 1)],\n        \"output_min\": 0.0,\n        \"output_max\": 1.0,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 3,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1000.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SinOfSum,\n        \"kernel_regularizer\": [(\"torsion\", 0.0, 1e-6),\n                               (\"laplacian\", 1e-6, 0.0)],\n        \"target_monotonicity_diff\": -1e-5,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    # Delta is large because regularizers for large lattice are prone to\n    # numerical errors due to summing up huge number of floats of various\n    # magnitudes hence loss is different in graph and eager modes.\n    self.assertAlmostEqual(loss, 0.97806, delta=0.05)\n\n  @parameterized.parameters(\n      ([0], [0], 0.026734),\n      ([1], [\"none\"], 0.195275),\n      ([1], None, 0.195275),\n      ([0], [\"Valley\"], 0.045627),\n      ([0], [\"peak\"], 0.045627),\n      ([0], [-1], 0.045627),\n      (None, [1], 0.045627),\n  )\n  def testUnimodalityOneD(self, monotonicities, unimodalities, expected_loss):\n    if self.disable_all:\n      return\n\n    def WShaped1dFunction(x):\n      d = min(abs(x[0] - 3.0), abs(x[0] - 7.0))\n      result = d * d / 4.0 - 2.0\n      # Mirroring to test opposite unimodality direction on same data.\n      if unimodalities:\n        if unimodalities[0] == -1 or unimodalities[0] == \"peak\":\n          result *= -1.0\n      return result\n\n    config = {\n        \"lattice_sizes\": [11],\n        \"num_training_records\": 128,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": WShaped1dFunction,\n        \"monotonicities\": monotonicities,\n        \"unimodalities\": unimodalities,\n        \"kernel_initializer\": \"linear_initializer\",\n        \"output_min\": -2.0,\n        \"output_max\": 2.0,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([0, 0], [0, 0], 0.003822),\n      ([1, 1], 
[0, 0], 0.313155),\n      ([0, 0], [1, 1], 0.003073),\n      ([1, 0], [0, 1], 0.162484),\n      ([0, 0], [1, 0], 0.004883),\n      ([0, 0], [-1, -1], 0.003073),\n      ([1, 0], [0, -1], 0.162484),\n      ([0, 0], [-1, 0], 0.004883),\n      ([0, 0], [-1, 1], 0.260546),\n  )\n  def testUnimodalityTwoD(self, monotonicities, unimodalities, expected_loss):\n    if self.disable_all:\n      return\n\n    def WShaped2dFunction(x):\n      distance = lambda x1, y1, x2, y2: ((x2 - x1)**2 + (y2 - y1)**2)**0.5\n      d = distance(x[0], x[1], 5.0, 5.0)\n      result = (d - 2.0)**2 / 8.0 - 2.0\n      # Mirroring to test opposite unimodality direction on same data.\n      if unimodalities[0] == -1 or unimodalities[1] == -1:\n        result *= -1.0\n      return result\n\n    config = {\n        \"lattice_sizes\": [11, 11],\n        \"num_training_records\": 900,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"y_function\": WShaped2dFunction,\n        \"monotonicities\": monotonicities,\n        \"unimodalities\": unimodalities,\n        \"kernel_initializer\": \"linear_initializer\",\n        \"output_min\": -2.0,\n        \"output_max\": 2.0,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  def testUnconstrained(self):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": [20],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._Sin,\n        \"kernel_initializer\": keras.initializers.Zeros,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000917, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._Square,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.004277, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": test_utils.get_hypercube_interpolation_fn(\n            coefficients=[0.0, 1.0, 1.0, 1.0]),\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000003, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2] * 3,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": test_utils.get_hypercube_interpolation_fn(\n            coefficients=[i / 2.0**3 for i in range(2**3)])\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    
self.assertAlmostEqual(loss, 0.000001, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2] * 5,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": test_utils.get_hypercube_interpolation_fn(\n            coefficients=[i / 2.0**5 for i in range(2**5)])\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000008, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2, 2],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._Max,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.003599, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2] * 6,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 300,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 30.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._PseudoLinear,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000118, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2, 3, 4],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._PseudoLinear,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.00002, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [4, 5],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._WeightedSum,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.0, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2, 3, 4, 5],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 30.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._Max,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000891, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2, 3, 4, 5],\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 30.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._WeightedSum,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.004216, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n   
     \"lattice_sizes\": [20],\n        \"interpolation\": \"simplex\",\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._Sin,\n        \"kernel_initializer\": keras.initializers.Zeros,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000917, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2],\n        \"interpolation\": \"simplex\",\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 50,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._Square,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.004277, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2, 2],\n        \"interpolation\": \"simplex\",\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._Max,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 5e-06, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2] * 6,\n        \"interpolation\": \"simplex\",\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 300,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 30.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._PseudoLinear,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.08056, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2, 3, 4],\n        \"interpolation\": \"simplex\",\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._PseudoLinear,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.04316, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [4, 5],\n        \"interpolation\": \"simplex\",\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 100,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._WeightedSum,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.0, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2, 3, 4, 5],\n        \"interpolation\": \"simplex\",\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 30.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._Max,\n    }  # pyformat: disable\n    loss = 
self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.000122, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n    config = {\n        \"lattice_sizes\": [2, 3, 4, 5],\n        \"interpolation\": \"simplex\",\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 30.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._WeightedSum,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.003793, delta=self.loss_eps)\n    self._TestEnsemble(config)\n\n  @parameterized.parameters(\n      ([2, 3, 4], 6.429155),\n      ([2, 3, 3], 13.390955),\n      ([2, 2, 3], 22.205267),\n      ([2, 2, 3, 3], 5.049051),\n      ([2, 2, 3, 2, 2], 5.3823),\n      ([2, 2, 3, 3, 2, 2], 67.775276),\n      ([2, 2, 2, 3, 3, 3], 156.755035),\n      ([3, 2, 2, 3, 3, 2], 104.419373),\n  )\n  def testEquallySizedDimsOptimization(self, lattice_sizes, expected_loss):\n    if self.disable_all:\n      return\n    config = {\n        \"lattice_sizes\": lattice_sizes,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 1,\n        \"optimizer\": keras.optimizers.legacy.Adagrad,\n        \"learning_rate\": 10.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._WeightedSum,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self.loss_eps)\n\n  @parameterized.parameters(\n      ([2, 2, 2, 2, 2, 2], 81),\n      ([2, 2, 3, 2, 3, 2], 117),\n      ([2, 2, 2, 2, 3, 3], 102),\n      ([2, 2, 2, 2, 2, 2, 2, 2, 2], 114),\n      ([2, 2, 2, 2, 2, 2, 3, 3, 3], 135),\n  )\n  def testGraphSize(self, lattice_sizes, expected_graph_size):\n    # If this test fails then you have modified core lattice interpolation logic\n    # in a way which increases the number of ops in the graph. Or maybe the\n    # Keras team changed something under the hood. 
Please ensure that this increase is\n    # unavoidable and try to minimize it.\n    if self.disable_all:\n      return\n    tf.compat.v1.disable_eager_execution()\n    tf.compat.v1.reset_default_graph()\n\n    layer = ll.Lattice(lattice_sizes=lattice_sizes)\n    input_tensor = tf.ones(shape=(1, len(lattice_sizes)))\n    layer(input_tensor)\n    graph_size = len(tf.compat.v1.get_default_graph().as_graph_def().node)\n\n    self.assertLessEqual(graph_size, expected_graph_size)\n\n  @parameterized.parameters(\n      (\n          \"random_uniform_or_linear_initializer\",\n          [3, 3, 3],\n          [([0, 1, 2], \"peak\")],\n          keras.initializers.RandomUniform,\n      ),\n      (\n          \"random_uniform_or_linear_initializer\",\n          [3, 3, 3],\n          [([0, 1, 2], \"valley\")],\n          keras.initializers.RandomUniform,\n      ),\n      (\n          \"random_uniform_or_linear_initializer\",\n          [3, 3, 3],\n          [([0, 1], \"valley\")],\n          ll.LinearInitializer,\n      ),\n      (\n          \"random_uniform_or_linear_initializer\",\n          [3, 3, 3],\n          [([0, 1], \"valley\"), ([2], \"peak\")],\n          ll.LinearInitializer,\n      ),\n      (\n          \"random_uniform_or_linear_initializer\",\n          [3, 3, 3],\n          None,\n          ll.LinearInitializer,\n      ),\n      (\n          \"linear_initializer\",\n          [3, 3, 3],\n          [([0, 1], \"valley\")],\n          ll.LinearInitializer,\n      ),\n      (\n          \"random_monotonic_initializer\",\n          [3, 3, 3],\n          [([0, 1], \"valley\")],\n          ll.RandomMonotonicInitializer,\n      ),\n  )\n  def testCreateKernelInitializer(self, kernel_initializer_id, lattice_sizes,\n                                  joint_unimodalities, expected_type):\n    self.assertEqual(\n        expected_type,\n        type(\n            ll.create_kernel_initializer(\n                kernel_initializer_id,\n                lattice_sizes,\n                monotonicities=None,\n                output_min=0.0,\n                output_max=1.0,\n                unimodalities=None,\n                joint_unimodalities=joint_unimodalities)))\n\n  @parameterized.parameters(\n      # Single Unit\n      (\n          [2, 2],\n          [[0.], [1.], [2.], [3.]],\n          [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]],\n          [[0.], [1.], [2.], [3.]],\n      ),\n      (\n          [3, 2],\n          [[-0.4], [0.9], [0.4], [-0.6], [-0.8], [0.6]],\n          [[0.8, 0.3], [0.3, 0.8], [2.0, 0.0], [2.0, 0.5], [2.0, 1.0]],\n          [[-0.06], [0.19], [-0.8], [-0.1], [0.6]],\n      ),\n      (\n          [2, 2, 2, 2, 2],\n          [[-0.2], [-0.7], [-0.8], [0.8], [-0.3], [-0.6], [0.4], [0.5], [-0.3],\n           [0.3], [0.9], [0.4], [0.3], [-0.7], [0.1], [0.8], [-0.7], [-0.6],\n           [0.9], [-0.2], [0.3], [0.2], [0.9], [-0.1], [-0.6], [0.8], [0.4],\n           [1], [0.5], [0.2], [0.8], [-0.8]],\n          [[0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1]],\n          [[-0.04], [-0.18]],\n      ),\n      (\n          [3, 2, 2],\n          [[0], [1], [0.5], [0.1], [-0.5], [-0.9], [0.6], [-0.7], [-0.4], [0.2],\n           [0], [0.8]],\n          [[0.1, 0.2, 0.3], [0.3, 0.2, 0.1], [1.1, 0.2, 0.3], [1.7, 0.2, 0.1]],\n          [[0.04], [-0.06], [-0.43], [-0.27]],\n      ),\n      # Multi Unit\n      (\n          [2, 2],\n          [\n              [1., 11., 111.],\n              [2., 22., 222.],\n              [3., 33., 333.],\n              [4., 44., 444.],\n          ],\n        
  [\n              [[0.0, 0.0], [0.0, 0.0], [1.0, 1.0]],\n              [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0]],\n              [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]],\n              [[1.0, 1.0], [1.0, 1.0], [0.0, 0.0]],\n          ],\n          [\n              [1., 11., 444.],\n              [2., 22., 333.],\n              [3., 33., 222.],\n              [4., 44., 111.],\n          ],\n      ),\n      (\n          [3, 2],\n          [\n              [-0.4, -4, -40, -400],\n              [0.9, 9, 90, 900],\n              [0.4, 4, 40, 400],\n              [-0.6, -6, -60, -600],\n              [-0.8, -8, -80, -800],\n              [0.6, 6, 60, 600],\n          ],\n          [\n              [[0.8, 0.3], [2.0, 1.0], [0.8, 0.3], [2.0, 1.0]],\n              [[0.3, 0.8], [2.0, 0.5], [0.3, 0.8], [2.0, 0.5]],\n              [[2.0, 0.0], [2.0, 0.0], [2.0, 0.0], [2.0, 0.0]],\n              [[2.0, 0.5], [0.3, 0.8], [2.0, 0.5], [0.3, 0.8]],\n              [[2.0, 1.0], [0.8, 0.3], [2.0, 1.0], [0.8, 0.3]],\n          ],\n          [\n              [-0.06, 6., -6., 600.],\n              [0.19, -1., 19., -100.],\n              [-0.8, -8., -80., -800.],\n              [-0.1, 1.9, -10., 190.],\n              [0.6, -0.6, 60., -60.],\n          ],\n      ),\n  )\n  def testSimplexInterpolation(self, lattice_sizes, kernel, inputs,\n                               expected_outputs):\n    if self.disable_all:\n      return\n\n    kernel = tf.constant(kernel, dtype=tf.float32)\n    inputs = tf.constant(inputs, dtype=tf.float32)\n    units = int(kernel.shape[1])\n    model = keras.models.Sequential([\n        ll.Lattice(\n            lattice_sizes,\n            units=units,\n            interpolation=\"simplex\",\n            kernel_initializer=keras.initializers.Constant(kernel),\n        ),\n    ])\n    outputs = model.predict(inputs)\n    self.assertAllClose(outputs, expected_outputs)\n\n  @parameterized.parameters(\n      (\n          [2, 2],\n          [\n              [0., 0.],\n              [1., 1.],\n              [0., 0.],\n              [2., 10.],\n          ],\n          None,\n          None,\n          0.0, 1.0,\n          [\n              [0., 0.],\n              [1., 1.],\n              [0., 0.],\n              [1., 1.],\n          ],\n      ),\n      (\n          [2, 2],\n          [\n              [0., 0.],\n              [1., 1.],\n              [0., 0.],\n              [2., 10.],\n          ],\n          [(0, 1, 1)],\n          None,\n          0.0, 1.0,\n          [\n              [0.0, 0.0],\n              [0.5, 0.1],\n              [0.0, 0.0],\n              [1.0, 1.0],\n          ],\n      ),\n  )\n  def testFinalizeConstraints(self, lattice_sizes, kernel, edgeworth_trusts,\n                              trapezoid_trusts, output_min, output_max,\n                              expected_output):\n    if self.disable_all:\n      return\n\n    kernel = tf.constant(kernel, dtype=tf.float32)\n    units = int(kernel.shape[1])\n    layer = ll.Lattice(\n        lattice_sizes,\n        units=units,\n        monotonicities=[1] * len(lattice_sizes),\n        edgeworth_trusts=edgeworth_trusts,\n        trapezoid_trusts=trapezoid_trusts,\n        output_min=output_min,\n        output_max=output_max,\n        kernel_initializer=keras.initializers.Constant(kernel),\n    )\n    layer.build(input_shape=(None, units, len(lattice_sizes)))\n    output = layer.finalize_constraints()\n    self.assertAllClose(output, expected_output)\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_lattice/python/linear_layer.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Layer which represents linear function. See class level comment.\n\nThis layer applies a linear transformation to the input tensor with an optional\nbias term. It supports monotonicity, monotonic dominance and fixed-norm\nconstraints.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\nfrom . import linear_lib\nfrom . import utils\n\nLINEAR_LAYER_KERNEL_NAME = \"linear_layer_kernel\"\nLINEAR_LAYER_BIAS_NAME = \"linear_layer_bias\"\n\n\nclass Linear(keras.layers.Layer):\n  # pyformat: disable\n  \"\"\"Layer which represents linear function.\n\n  Monotonicity can be specified for any input dimension in which case learned\n  weight for that dimension is guaranteed to be either non negative for\n  increasing or non positive for decreasing monotonicity.\n\n  Monotonic dominance can be specified for any pair of dimensions referred to as\n  *dominant* and *weak* dimensions such that the effect (slope) in the direction\n  of the *dominant* dimension to be greater than that of the *weak* dimension\n  for any point. Both dominant and weak dimensions must be increasing.\n\n  Range dominance can be specified for any pair of *dominant* and *weak*\n  dimensions such that the range of possible outputs to be greater if one varies\n  the *dominant* dimension than if one varies the *weak* dimension for any\n  point. We require the slope of the *dominant* dimension scaled by its input\n  range to be greater than the slope of the *weak* dimension similarly scaled by\n  its input range. Both dimensions must have the same direction of monotonicity\n  and their input min and max must be provided.\n\n  Weights can be constrained to have a fixed norm.\n\n  Input shape:\n    - if `units == 1`: tensor of shape: `(batch_size, num_input_dims)`.\n    - if `units > 1`: tensor of shape: `(batch_size, units, num_input_dims)`\n\n  Output shape:\n  Rank-2 tensor with shape: (batch_size, units)\n\n  Attributes:\n    - All `__init__ `arguments.\n    kernel: layer's kernel.\n    bias: layer's bias. Only available if `use_bias == True`.\n\n  Example:\n\n  ```python\n  layer = tfl.layers.Linear(\n      num_input_dims=8,\n      # Monotonicity constraints can be defined per dimension or for all dims.\n      monotonicities='increasing',\n      use_bias=True,\n      # You can force the L1 norm to be 1. 
Since this is a monotonic layer,\n      # the coefficients will sum to 1, making this a \"weighted average\".\n      normalization_order=1)\n  ```\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self,\n               num_input_dims,\n               units=1,\n               monotonicities=None,\n               monotonic_dominances=None,\n               range_dominances=None,\n               input_min=None,\n               input_max=None,\n               use_bias=True,\n               normalization_order=None,\n               kernel_initializer=\"random_uniform\",\n               bias_initializer=\"random_uniform\",\n               kernel_regularizer=None,\n               bias_regularizer=None,\n               **kwargs):\n    \"\"\"initializes an instance of `Linear`.\n\n    Args:\n      num_input_dims: Number of input dimensions.\n      units: Output dimension of the layer.\n      monotonicities: None or list or tuple of length 'num_input_dims' of\n        {'decreasing', 'none', 'increasing', -1, 0, 1} which specifies if the\n        model output should be monotonic in corresponding feature, using\n        'increasing' or 1 to indicate increasing monotonicity, 'decreasing' or\n        -1 to indicate decreasing monotonicity and 'none' or 0 to indicate no\n        monotonicity constraints. In case of decreasing monotonicity\n        corresponding weight will be constrained to be non positive, in case of\n        increasing non-negative. Instead of a list or tuple single value can be\n        specified to indicate the monotonicity constraint across all dimensions.\n      monotonic_dominances: None or list of two-element tuples. First element is\n        the index of the dominant dimension. Second element is the index of the\n        weak dimension.\n      range_dominances: None or list of two-element tuples. First element is the\n        index of the dominant dimension. Second element is the index of the weak\n        dimension. Both dominant and weak dimensions must have input_min and\n        input_max set.\n      input_min: None of list or tuple of length 'num_input_dims' of either\n        'none' or float which specifies the minimum value to clip by for each\n        dimension.\n      input_max: None of list or tuple of length 'num_input_dims' of either\n        'none' or float which specifies the maximum value to clip by for each\n        dimension.\n      use_bias: Whether linear function has bias.\n      normalization_order: If specified learned weights will be adjusted to have\n        norm 1. Norm will be computed by: `tf.norm(tensor,\n        ord=normalization_order)`.\n      kernel_initializer: Any keras initializer to be applied to kernel.\n      bias_initializer: Any keras initializer to be applied to bias. 
Only valid\n        if `use_bias == True`.\n      kernel_regularizer: None or single element or list of any Keras\n        regularizer objects.\n      bias_regularizer: None or single element or list of any Keras regularizer\n        objects.\n      **kwargs: Other args passed to `keras.layers.Layer` initializer.\n\n    Raises:\n      ValueError: if monotonicity specified incorrectly.\n    \"\"\"\n    super(Linear, self).__init__(**kwargs)\n\n    self.num_input_dims = num_input_dims\n    self.units = units\n\n    if isinstance(monotonicities, list) or isinstance(monotonicities, tuple):\n      self.monotonicities = list(monotonicities)\n    elif monotonicities is not None:\n      self.monotonicities = [monotonicities] * self.num_input_dims\n    else:\n      self.monotonicities = [0] * self.num_input_dims\n    self.monotonic_dominances = monotonic_dominances\n    self.range_dominances = range_dominances\n    self.input_min = input_min\n    self.input_max = input_max\n    # Verify hyperparameters after converting monotonicities to list because\n    # internally everything expects monotonicites to be list or tuple rather\n    # than single element.\n    linear_lib.verify_hyperparameters(\n        num_input_dims=self.num_input_dims, monotonicities=self.monotonicities)\n\n    self.use_bias = use_bias\n    self.normalization_order = normalization_order\n    self.kernel_initializer = keras.initializers.get(kernel_initializer)\n    if use_bias:\n      self.bias_initializer = keras.initializers.get(bias_initializer)\n\n    self.kernel_regularizer = []\n    if kernel_regularizer:\n      if callable(kernel_regularizer):\n        kernel_regularizer = [kernel_regularizer]\n      for reg in kernel_regularizer:\n        self.kernel_regularizer.append(keras.regularizers.get(reg))\n    self.bias_regularizer = []\n    if bias_regularizer:\n      if callable(bias_regularizer):\n        bias_regularizer = [bias_regularizer]\n      for reg in bias_regularizer:\n        self.bias_regularizer.append(keras.regularizers.get(reg))\n\n    if units == 1:\n      input_shape = (None, num_input_dims)\n    else:\n      input_shape = (None, units, num_input_dims)\n    self.input_spec = keras.layers.InputSpec(\n        dtype=self.dtype, shape=input_shape)\n\n  def build(self, input_shape):\n    \"\"\"Standard Keras build() method.\n\n    Args:\n      input_shape: Must be: (batch_size, num_input_dims) if units == 1, or\n        (batch_size, units, num_input_dims) if units > 1.\n\n    Raises:\n      ValueError: If shape is invalid.\n    \"\"\"\n    linear_lib.verify_hyperparameters(\n        num_input_dims=self.num_input_dims,\n        units=self.units,\n        input_shape=input_shape)\n\n    if (any(self.monotonicities) or self.monotonic_dominances or\n        self.range_dominances or self.normalization_order):\n      constraints = LinearConstraints(\n          monotonicities=self.monotonicities,\n          monotonic_dominances=self.monotonic_dominances,\n          range_dominances=self.range_dominances,\n          input_min=self.input_min,\n          input_max=self.input_max,\n          normalization_order=self.normalization_order)\n    else:\n      constraints = None\n\n    if not self.kernel_regularizer:\n      kernel_reg = None\n    elif len(self.kernel_regularizer) == 1:\n      kernel_reg = self.kernel_regularizer[0]\n    else:\n      # Keras interface assumes only one regularizer, so summ all regularization\n      # losses which we have.\n      kernel_reg = lambda x: tf.add_n([r(x) for r in 
self.kernel_regularizer])\n\n    self.kernel = self.add_weight(\n        LINEAR_LAYER_KERNEL_NAME,\n        # 1 column matrix rather than verctor for matrix multiplication.\n        shape=[self.num_input_dims, self.units],\n        initializer=self.kernel_initializer,\n        regularizer=kernel_reg,\n        constraint=constraints,\n        dtype=self.dtype)\n\n    if self.use_bias:\n      if not self.bias_regularizer:\n        bias_reg = None\n      elif len(self.bias_regularizer) == 1:\n        bias_reg = self.bias_regularizer[0]\n      else:\n        bias_reg = lambda x: tf.add_n([r(x) for r in self.bias_regularizer])\n      self.bias = self.add_weight(\n          LINEAR_LAYER_BIAS_NAME,\n          shape=[] if self.units == 1 else [self.units],\n          initializer=self.bias_initializer,\n          regularizer=bias_reg,\n          constraint=None,\n          dtype=self.dtype)\n\n    input_min = utils.canonicalize_input_bounds(self.input_min)\n    input_max = utils.canonicalize_input_bounds(self.input_max)\n    if ((input_min and input_min.count(None) < len(input_min)) or\n        (input_max and input_max.count(None) < len(input_max))):\n      lower_bounds = [val if val is not None else -np.inf\n                      for val in input_min or [None] * self.num_input_dims]\n      upper_bounds = [val if val is not None else np.inf\n                      for val in input_max or [None] * self.num_input_dims]\n      self.clip_value_min = tf.constant(lower_bounds, dtype=self.dtype)\n      self.clip_value_max = tf.constant(upper_bounds, dtype=self.dtype)\n    else:\n      self.clip_value_min = None\n      self.clip_value_max = None\n\n    super(Linear, self).build(input_shape)\n\n  def call(self, inputs):\n    \"\"\"Standard Keras call() method.\"\"\"\n    if self.clip_value_min is not None and self.clip_value_max is not None:\n      inputs = tf.clip_by_value(inputs,\n                                clip_value_min=self.clip_value_min,\n                                clip_value_max=self.clip_value_max)\n\n    if self.units == 1:\n      result = tf.matmul(inputs, self.kernel)\n    else:\n      result = tf.reduce_sum(inputs * tf.transpose(self.kernel), axis=-1)\n    if self.use_bias:\n      result += self.bias\n    return result\n\n  def compute_output_shape(self, input_shape):\n    \"\"\"Standard Keras compute_output_shape() method.\"\"\"\n    del input_shape\n    return [None, self.units]\n\n  def get_config(self):\n    \"\"\"Standard Keras get_config() method.\"\"\"\n    config = {\n        \"num_input_dims\": self.num_input_dims,\n        \"units\": self.units,\n        \"monotonicities\": self.monotonicities,\n        \"use_bias\": self.use_bias,\n        \"normalization_order\": self.normalization_order,\n        \"monotonic_dominances\": self.monotonic_dominances,\n        \"range_dominances\": self.range_dominances,\n        \"input_min\": self.input_min,\n        \"input_max\": self.input_max,\n        \"kernel_initializer\":\n            keras.initializers.serialize(\n                self.kernel_initializer, use_legacy_format=True),\n        \"kernel_regularizer\": [\n            keras.regularizers.serialize(r, use_legacy_format=True)\n            for r in self.kernel_regularizer\n        ],\n    }  # pyformat: disable\n    if self.use_bias:\n      config[\"bias_initializer\"] = keras.initializers.serialize(\n          self.bias_initializer, use_legacy_format=True\n      )\n      config[\"bias_regularizer\"] = [\n          keras.regularizers.serialize(r, use_legacy_format=True)\n       
   for r in self.bias_regularizer\n      ]\n\n    config.update(super(Linear, self).get_config())\n    return config\n\n  # Default eps is bigger than one for other layers because normalization is\n  # prone to numerical errors.\n  def assert_constraints(self, eps=1e-4):\n    \"\"\"Asserts that weights satisfy all constraints.\n\n    In graph mode builds and returns list of assertion ops.\n    In eager mode directly executes assertions.\n\n    Args:\n      eps: Allowed constraints violation.\n\n    Returns:\n      List of assertion ops in graph mode or immediately asserts in eager mode.\n    \"\"\"\n    return linear_lib.assert_constraints(\n        weights=self.kernel,\n        monotonicities=utils.canonicalize_monotonicities(self.monotonicities),\n        monotonic_dominances=self.monotonic_dominances,\n        range_dominances=self.range_dominances,\n        input_min=utils.canonicalize_input_bounds(self.input_min),\n        input_max=utils.canonicalize_input_bounds(self.input_max),\n        normalization_order=self.normalization_order,\n        eps=eps)\n\n\nclass LinearConstraints(keras.constraints.Constraint):\n  # pyformat: disable\n  \"\"\"Applies monotonicity constraints and normalization to TFL Linear layer.\n\n  Monotonicity is specified per input dimension in which case learned weight for\n  those dimensions is guaranteed to be either non negative for increasing or non\n  positive for decreasing monotonicity.\n\n  Monotonic dominance can be specified for any pair of dimensions referred to as\n  *dominant* and *weak* dimensions such that the effect (slope) in the direction\n  of the *dominant* dimension to be greater than that of the *weak* dimension\n  for any point. Both dominant and weak dimensions must be increasing.\n\n  Range dominance can be specified for any pair of *dominant* and *weak*\n  dimensions such that the range of possible outputs to be greater if one varies\n  the *dominant* dimension than if one varies the *weak* dimension for any\n  point. We require the slope of the *dominant* dimension scaled by its input\n  range to be greater than the slope of the *weak* dimension similarly scaled by\n  its input range. 
Both dimensions must have the same direction of monotonicity\n  and their input min and max must be provided.\n\n  Weights can be constrained to have norm 1.\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self, monotonicities, monotonic_dominances=None,\n               range_dominances=None, input_min=None, input_max=None,\n               normalization_order=None):\n    \"\"\"initializes an instance of `LinearConstraints`.\n\n    Args:\n      monotonicities: Same meaning as corresponding parameter of `Linear`.\n      monotonic_dominances: Same meaning as corresponding parameter of `Linear`.\n      range_dominances: Same meaning as corresponding parameter of `Linear`.\n      input_min: Same meaning as corresponding parameter of `Linear`.\n      input_max: Same meaning as corresponding parameter of `Linear`.\n      normalization_order: Same meaning as corresponding parameter of `Linear`.\n    \"\"\"\n    linear_lib.verify_hyperparameters(monotonicities=monotonicities,\n                                      monotonic_dominances=monotonic_dominances,\n                                      range_dominances=range_dominances,\n                                      input_min=input_min,\n                                      input_max=input_max)\n    self.monotonicities = monotonicities\n    self.monotonic_dominances = monotonic_dominances\n    self.range_dominances = range_dominances\n    self.input_min = input_min\n    self.input_max = input_max\n    self.normalization_order = normalization_order\n\n  def __call__(self, w):\n    \"\"\"Applies constraints to w.\n\n    Args:\n      w: Tensor which represents weights of TFL linear layer. Must have shape:\n        `(len(self.monotonicities), 1)`.\n\n    Raises:\n      ValueError: if shape of `w` is not `(len(self.monotonicities), 1)`.\n\n    Returns:\n      Tensor `w` with monotonicity constraints and normalization applied to it.\n    \"\"\"\n    return linear_lib.project(\n        weights=w,\n        monotonicities=utils.canonicalize_monotonicities(self.monotonicities),\n        monotonic_dominances=self.monotonic_dominances,\n        range_dominances=self.range_dominances,\n        input_min=utils.canonicalize_input_bounds(self.input_min),\n        input_max=utils.canonicalize_input_bounds(self.input_max),\n        normalization_order=self.normalization_order)\n\n  def get_config(self):\n    \"\"\"Standard Keras get_config() method.\"\"\"\n    return {\n        \"monotonicities\": self.monotonicities,\n        \"monotonic_dominances\": self.monotonic_dominances,\n        \"range_dominances\": self.range_dominances,\n        \"input_min\": self.input_min,\n        \"input_max\": self.input_max,\n        \"normalization_order\": self.normalization_order\n    }  # pyformat: disable\n"
  },
  {
    "path": "tensorflow_lattice/python/linear_lib.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Implementation of algorithms required for Linear layer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom . import internal_utils\nfrom . import utils\nimport six\nimport tensorflow as tf\n\n_NORMALIZATION_EPS = 1e-8\n\n\ndef project(weights,\n            monotonicities,\n            monotonic_dominances=None,\n            range_dominances=None,\n            input_min=None,\n            input_max=None,\n            normalization_order=None):\n  \"\"\"Applies constraints to weights.\n\n  Args:\n    weights: Tensor which represents weights of TFL linear layer. Must have\n      shape [len(monotonicities), units].\n    monotonicities: List or tuple of same length as number of elements in\n      'weights' of {-1, 0, 1} which represent monotonicity constraints per\n      dimension. -1 stands for decreasing, 0 for no constraints, 1 for\n      increasing.\n    monotonic_dominances: List of two-element tuples. First element is the index\n      of the dominant feature. Second element is the index of the weak feature.\n    range_dominances: List of two-element tuples. First element is the index of\n      the dominant feature. Second element is the index of the weak feature.\n    input_min: List or tuple of length same length as number of elements in\n      'weights' of either None or float to compute input range for range\n      dominance projection.\n    input_max: List or tuple of length same length as number of elements in\n      'weights' of either None or float to compute input range for range\n      dominance projection.\n    normalization_order: If specified weights will be adjusted to have norm 1.\n      Norm will be computed by: `tf.norm(tensor, ord=normalization_order)`.\n\n  Raises:\n    ValueError: If shape of weights is not `(len(monotonicities), units)`.\n\n  Returns:\n    'weights' with monotonicity constraints and normalization applied to it.\n  \"\"\"\n  verify_hyperparameters(\n      weights_shape=weights.shape,\n      monotonicities=monotonicities,\n      monotonic_dominances=monotonic_dominances,\n      range_dominances=range_dominances,\n      input_min=input_min,\n      input_max=input_max)\n  if any(monotonicities):\n    if 1 in monotonicities:\n      inverted_increasing_mask = tf.constant(\n          value=[0.0 if m == 1 else 1.0 for m in monotonicities],\n          dtype=weights.dtype,\n          shape=(weights.shape[0], 1))\n      # Multiplying by this mask will keep non monotonic dims same and will\n      # set monotonic dims to 0.0. 
Later by taking maximum with this product\n      # we'll essentially take maximumum of monotonic dims with 0.0.\n      weights = tf.maximum(weights, weights * inverted_increasing_mask)\n\n    if -1 in monotonicities:\n      inverted_decreasing_mask = tf.constant(\n          value=[0.0 if m == -1 else 1.0 for m in monotonicities],\n          dtype=weights.dtype,\n          shape=(weights.shape[0], 1))\n      weights = tf.minimum(weights, weights * inverted_decreasing_mask)\n\n  if monotonic_dominances:\n    monotonic_dominances = [(j, i) for i, j in monotonic_dominances]\n    weights = internal_utils.approximately_project_categorical_partial_monotonicities(\n        weights, monotonic_dominances)\n\n  if range_dominances:\n    range_dominances = [(j, i) for i, j in range_dominances]\n    scalings = [-1.0 if m == -1 else 1.0 for m in monotonicities]\n    for dim, (lower, upper) in enumerate(zip(input_min, input_max)):\n      if lower is not None and upper is not None:\n        scalings[dim] *= upper - lower\n    scalings = tf.constant(\n        scalings, dtype=weights.dtype, shape=(weights.shape[0], 1))\n    weights *= scalings\n    weights = internal_utils.approximately_project_categorical_partial_monotonicities(\n        weights, range_dominances)\n    weights /= scalings\n\n  if normalization_order:\n    norm = tf.norm(weights, axis=0, ord=normalization_order)\n    norm = tf.where(norm < _NORMALIZATION_EPS, 1.0, norm)\n    weights = weights / norm\n\n  return weights\n\n\ndef assert_constraints(weights,\n                       monotonicities,\n                       monotonic_dominances,\n                       range_dominances,\n                       input_min,\n                       input_max,\n                       normalization_order,\n                       eps=1e-4):\n  \"\"\"Asserts that weights satisfy constraints.\n\n  Args:\n    weights: Weights of Linear layer.\n    monotonicities: List or tuple of same length as number of elements in\n      'weights' of {-1, 0, 1} which represent monotonicity constraints per\n      dimension. -1 stands for decreasing, 0 for no constraints, 1 for\n      increasing.\n    monotonic_dominances: List of two-element tuple. First element is the index\n      of the dominant feature. Second element is the index of the weak feature.\n    range_dominances: List of two-element tuples. First element is the index of\n      the dominant feature. Second element is the index of the weak feature.\n    input_min: List or tuple of length same length as number of elements in\n      'weights' of either None or float which specifies the minimum value to\n      clip by.\n    input_max: List or tuple of length same length as number of elements in\n      'weights' of either None or float which specifies the maximum value to\n      clip by.\n    normalization_order: Whether weights have to have norm 1. 
Norm will be\n      computed by: `tf.norm(tensor, ord=normalization_order)`.\n    eps: Allowed constraints violation.\n\n  Returns:\n    List of assertion ops in graph mode or directly executes assertions in eager\n    mode.\n  \"\"\"\n  asserts = []\n  if any(monotonicities):\n    # Create constant specifying shape explicitly because otherwise due to\n    # weights shape ending with dimension of size 1 broadcasting will hurt us.\n    monotonicities_constant = tf.constant(\n        monotonicities, shape=(weights.shape[0], 1), dtype=weights.dtype)\n    diff = tf.reduce_min(weights * monotonicities_constant)\n    asserts.append(\n        tf.Assert(\n            diff >= -eps,\n            data=[\n                \"Monotonicity violation\", \"Monotonicities:\", monotonicities,\n                \"Min monotonicity diff:\", diff, \"Epsilon:\", eps, \"Weights:\",\n                weights\n            ],\n            summarize=weights.shape[0]))\n\n  for dominant_dim, weak_dim in monotonic_dominances or []:\n    diff = tf.reduce_min(weights[dominant_dim] - weights[weak_dim])\n    asserts.append(\n        tf.Assert(\n            diff >= -eps,\n            data=[\n                \"Monotonic dominance violation\", \"Dominant dim:\", dominant_dim,\n                \"Weak dim:\", weak_dim, \"Epsilon:\", eps, \"Weights:\", weights\n            ],\n            summarize=weights.shape[0]))\n\n  if range_dominances:\n    scalings = [-1.0 if m == -1 else 1.0 for m in monotonicities]\n    for dim, (lower, upper) in enumerate(zip(input_min, input_max)):\n      if lower is not None and upper is not None:\n        scalings[dim] *= upper - lower\n    for dominant_dim, weak_dim in range_dominances:\n      diff = tf.reduce_min(scalings[dominant_dim] * weights[dominant_dim] -\n                           scalings[weak_dim] * weights[weak_dim])\n      asserts.append(\n          tf.Assert(\n              diff >= -eps,\n              data=[\n                  \"Range dominance violation\", \"Dominant dim:\", dominant_dim,\n                  \"Weak dim:\", weak_dim, \"Epsilon:\", eps, \"Weights:\", weights,\n                  \"Scalings:\", scalings\n              ],\n              summarize=weights.shape[0]))\n\n  if normalization_order:\n    norm = tf.norm(weights, axis=0, ord=normalization_order)\n    asserts.append(\n        # Norm can be either 0.0 or 1.0, because if all weights are close to 0.0\n        # we can't scale them to get norm 1.0.\n        tf.Assert(\n            tf.logical_or(\n                tf.abs(norm - 1.0) < eps,\n                tf.abs(norm) < _NORMALIZATION_EPS),\n            data=[\n                \"Normalization order violation\", \"Norm:\", norm, \"Epsilon:\", eps,\n                \"Weights:\", weights\n            ],\n            summarize=weights.shape[0]))\n  return asserts\n\n\ndef verify_hyperparameters(num_input_dims=None,\n                           units=None,\n                           input_shape=None,\n                           monotonicities=None,\n                           monotonic_dominances=None,\n                           range_dominances=None,\n                           input_min=None,\n                           input_max=None,\n                           weights_shape=None):\n  \"\"\"Verifies that all given hyperparameters are consistent.\n\n  This function does not inspect weights themselves. Only their shape. 
Use\n  `assert_constraints()` to assert actual weights against constraints.\n\n  Unlike linear layer itself this function requires monotonicites to be\n  specified via list or tuple rather than via single element because that's how\n  monotonicites are stored internaly.\n\n  See `tfl.layers.Linear` Layer class level comment for detailed description of\n  arguments.\n\n  Args:\n    num_input_dims: None or number of input dimensions.\n    units: Units hyperparameter of Linear layer.\n    input_shape: Shape of layer input.\n    monotonicities: List or tuple of same length as number of elements in\n      `weights` of {-1, 0, 1} which represent monotonicity constraints per\n      dimension. -1 stands for decreasing, 0 for no constraints, 1 for\n      increasing.\n    monotonic_dominances: List of two-element tuples. First element is the index\n      of the dominant feature. Second element is the index of the weak feature.\n    range_dominances: List of two-element tuples. First element is the index of\n      the dominant feature. Second element is the index of the weak feature.\n    input_min: List or tuple of length same length as number of elements in\n      'weights' of either None or float which specifies the minimum value to\n      clip by.\n    input_max: List or tuple of length same length as number of elements in\n      'weights' of either None or float which specifies the maximum value to\n      clip by.\n    weights_shape: None or shape of tensor which represents weights of Linear\n      layer.\n\n  Raises:\n    ValueError: If something is inconsistent.\n  \"\"\"\n  # It also raises errors if monotonicities specified incorrectly.\n  monotonicities = utils.canonicalize_monotonicities(monotonicities)\n  input_min = utils.canonicalize_input_bounds(input_min)\n  input_max = utils.canonicalize_input_bounds(input_max)\n\n  if monotonicities is not None and num_input_dims is not None:\n    if len(monotonicities) != num_input_dims:\n      raise ValueError(\"Number of elements in 'monotonicities' must be equal to\"\n                       \" num_input_dims. monotoniticites: %s, \"\n                       \"len(monotonicities): %d, num_input_dims: %d\" %\n                       (monotonicities, len(monotonicities), num_input_dims))\n\n  if weights_shape is not None:\n    if len(weights_shape) != 2:\n      raise ValueError(\"Expect weights to be a rank 2 tensor. Weights shape: \"\n                       \"%s\" % (weights_shape,))\n    if monotonicities is not None and weights_shape[0] != len(monotonicities):\n      raise ValueError(\"Number of elements in 'monotonicities' does not \"\n                       \"correspond to number of weights. Weights shape: %s, \"\n                       \"monotonicities: %s\" % (weights_shape, monotonicities))\n    if input_min is not None and weights_shape[0] != len(input_min):\n      raise ValueError(\n          \"Number of elements in 'input_min' does not correspond \"\n          \"to number of weights. Weights shape: %s, input_min: %s\" %\n          (weights_shape, input_min))\n    if input_max is not None and weights_shape[0] != len(input_max):\n      raise ValueError(\n          \"Number of elements in 'input_max' does not correspond \"\n          \"to number of weights. 
Weights shape: %s, input_max: %s\" %\n          (weights_shape, input_max))\n\n  if input_shape is not None:\n    assert units is not None and num_input_dims is not None\n    if (units > 1 and\n        (len(input_shape) != 3 or input_shape[1] != units or\n         input_shape[2] != num_input_dims)):\n      raise ValueError(\"'input_shape' must be of rank three and number of \"\n                       \"elements of second and third dimensions must be \"\n                       \"equal to 'units' and 'num_input_dims' respectively. \"\n                       \"'input_shape': \" + str(input_shape) + \"'units': \" +\n                       str(units) + \"'num_input_dims': \" + str(num_input_dims))\n    elif (units == 1 and\n          (len(input_shape) != 2 or input_shape[1] != num_input_dims)):\n      raise ValueError(\"'input_shape' must be of rank two and number of \"\n                       \"elements of second dimension must be equal to \"\n                       \"'num_input_dims'. 'input_shape': \" + str(input_shape) +\n                       \"'num_input_dims': \" + str(num_input_dims))\n\n  for dim, (lower, upper) in enumerate(zip(input_min or [], input_max or [])):\n    if lower is not None and upper is not None and lower > upper:\n      raise ValueError(\"Cannot have 'input_min' greater than 'input_max'.\"\n                       \"Dimension: %d, input_min[%d]: %f, input_max[%d]: %f\" %\n                       (dim, dim, input_min[dim], dim, input_max[dim]))\n\n  if monotonic_dominances is not None:\n    assert monotonicities is not None\n    num_input_dims = len(monotonicities)\n    dim_pairs = set()\n    for constraint in monotonic_dominances:\n      if len(constraint) != 2:\n        raise ValueError(\"Monotonic dominance constraints must consist of 2 \"\n                         \"elements. Seeing constraint tuple %s\" % (constraint,))\n      dominant_dim, weak_dim = constraint\n      if (dominant_dim >= num_input_dims or weak_dim >= num_input_dims or\n          dominant_dim < 0 or weak_dim < 0):\n        raise ValueError(\"Dimensions constrained by monotonic dominance \"\n                         \"constraints are not within the input dimensions. \"\n                         \"'dims': %s, %s, num_dims: %s\" %\n                         (dominant_dim, weak_dim, num_input_dims))\n      if not isinstance(dominant_dim, int) or not isinstance(weak_dim, int):\n        raise ValueError(\"Monotonic dominance constraint dimensions must be \"\n                         \"integers. Seeing dominant_dim %s and weak_dim %s\" %\n                         (dominant_dim, weak_dim))\n      for dim in [dominant_dim, weak_dim]:\n        if monotonicities[dim] != 1:\n          raise ValueError(\"Monotonic dominance constraint's dimensions must \"\n                           \"be monotonic. Dimension %d is not monotonic.\" %\n                           (dim))\n      if (weak_dim, dominant_dim) in dim_pairs:\n        raise ValueError(\"Cannot have two monotonic dominance constraints on \"\n                         \"the same pair of features conflicting. Features: %d, \"\n                         \"%d\" % (dominant_dim, weak_dim))\n      dim_pairs.add((dominant_dim, weak_dim))\n\n  if range_dominances is not None:\n    assert monotonicities is not None\n    num_input_dims = len(monotonicities)\n    dim_pairs = set()\n    for constraint in range_dominances:\n      if len(constraint) != 2:\n        raise ValueError(\"Range dominance constraints must consist of 2 \"\n                         \"elements. 
Seeing constraint tuple %s\" % (constraint,))\n      dominant_dim, weak_dim = constraint\n      if (dominant_dim >= num_input_dims or weak_dim >= num_input_dims or\n          dominant_dim < 0 or weak_dim < 0):\n        raise ValueError(\"Dimensions constrained by range dominance \"\n                         \"constraints are not within the input dimensions. \"\n                         \"'dims': %s, %s, num_dims: %s\" %\n                         (dominant_dim, weak_dim, num_input_dims))\n      if not isinstance(dominant_dim, int) or not isinstance(weak_dim, int):\n        raise ValueError(\"Range dominance constraint dimensions must be \"\n                         \"integers. Seeing dominant_dim %s and weak_dim %s\" %\n                         (dominant_dim, weak_dim))\n      if (monotonicities[dominant_dim] != monotonicities[weak_dim] or\n          monotonicities[dominant_dim] == 0):\n        raise ValueError(\"Range dominance constraint's dimensions must have \"\n                         \"the same direction of monotonicity. Dimension %d is \"\n                         \"%d. Dimension %d is %d.\" %\n                         (dominant_dim, monotonicities[dominant_dim], weak_dim,\n                          monotonicities[weak_dim]))\n      for dim in [dominant_dim, weak_dim]:\n        if input_min is None or input_min[dim] is None:\n          raise ValueError(\"Range dominance constraint's dimensions must \"\n                           \"have `input_min` set. Dimension %d is not set.\" %\n                           (dim))\n        if input_max is None or input_max[dim] is None:\n          raise ValueError(\"Range dominance constraint's dimensions must \"\n                           \"have `input_max` set. Dimension %d is not set.\" %\n                           (dim))\n      if (weak_dim, dominant_dim) in dim_pairs:\n        raise ValueError(\"Cannot have two range dominance constraints on the \"\n                         \"same pair of features conflicting. Features: %d, %d\" %\n                         (dominant_dim, weak_dim))\n      dim_pairs.add((dominant_dim, weak_dim))\n\n  if range_dominances is not None and monotonic_dominances is not None:\n    monotonic_dominance_dims = set()\n    for dims in monotonic_dominances:\n      for dim in dims:\n        monotonic_dominance_dims.add(dim)\n    for dims in range_dominances:\n      for dim in dims:\n        if dim in monotonic_dominance_dims:\n          raise ValueError(\"Cannot have both monotonic and range dominance \"\n                           \"constraints specified on the same dimension. \"\n                           \"Dimension %d is set by both.\" % (dim))\n"
  },
  {
    "path": "tensorflow_lattice/python/linear_test.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for Tensorflow Lattice linear layer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\n\nfrom absl import logging\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_lattice.python import linear_layer as linl\nfrom tensorflow_lattice.python import test_utils\nfrom tensorflow_lattice.python import utils\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n_DISABLE_ALL = False\n_LOSS_EPS = 0.0001\n_SMALL_EPS = 1e-6\n\n\nclass LinearTest(parameterized.TestCase, tf.test.TestCase):\n  \"\"\"Tests for TFL linear layer.\"\"\"\n\n  def setUp(self):\n    super(LinearTest, self).setUp()\n    keras.utils.set_random_seed(42)\n\n  def _ResetAllBackends(self):\n    keras.backend.clear_session()\n    tf.compat.v1.reset_default_graph()\n\n  def _ScaterXUniformly(self, num_points, num_dims, input_min, input_max):\n    \"\"\"Generates num_points num_dims-dimensional points within given range.\"\"\"\n    np.random.seed(41)\n    x = []\n    for _ in range(num_points):\n      point = [\n          np.random.random() * (input_max - input_min) + input_min\n          for _ in range(num_dims)\n      ]\n      x.append(np.asarray(point))\n    if num_dims == 1:\n      x.sort()\n    return x\n\n  def _TwoDMeshGrid(self, num_points, num_dims, input_min, input_max):\n    \"\"\"Mesh grid for visualisation of 3-d surfaces via pyplot.\"\"\"\n    if num_dims != 2:\n      raise ValueError(\"2-d mesh grid can be created only for 2-d data. Given: \"\n                       \"%d.\" % num_dims)\n    return test_utils.two_dim_mesh_grid(\n        num_points=num_points,\n        x_min=input_min,\n        y_min=input_min,\n        x_max=input_max,\n        y_max=input_max)\n\n  def _GenLinearFunction(self, weights, bias=0.0, noise=None):\n    \"\"\"Returns python function which computes linear function.\"\"\"\n\n    def Linear(x):\n      if len(x) != len(weights):\n        raise ValueError(\"X and weights have different number of elements. 
\"\n                         \"X: \" + str(x) + \"; weights: \" + str(weights))\n      result = bias\n      if noise:\n        result += noise(x)\n      for (i, y) in enumerate(x):\n        result += weights[i] * y\n      return result\n\n    return Linear\n\n  def _SinPlusXPlusD(self, x):\n    return math.sin(x[0]) + x[0] / 3.0 + 3.0\n\n  def _SetDefaults(self, config):\n    config.setdefault(\"monotonicities\", None)\n    config.setdefault(\"monotonic_dominances\", None)\n    config.setdefault(\"range_dominances\", None)\n    config.setdefault(\"clip_min\", None)\n    config.setdefault(\"clip_max\", None)\n    config.setdefault(\"use_bias\", False)\n    config.setdefault(\"normalization_order\", None)\n    config.setdefault(\"kernel_init_constant\", 0.0)\n    config.setdefault(\"bias_init_constant\", 0.0)\n    config.setdefault(\"kernel_regularizer\", None)\n    config.setdefault(\"bias_regularizer\", None)\n    config.setdefault(\"allowed_constraints_violation\", 1e-6)\n    config.setdefault(\"units\", 1)\n    config.setdefault(\"unit_index\", 0)\n    return config\n\n  def _GetTrainingInputsAndLabels(self, config):\n    \"\"\"Generates training inputs and labels.\n\n    Args:\n      config: Dict with config for this unit test.\n\n    Returns:\n      Tuple `(training_inputs, training_labels, raw_training_inputs)` where\n        `training_inputs` and `training_labels` are data for training.\n    \"\"\"\n    raw_training_inputs = config[\"x_generator\"](\n        num_points=config[\"num_training_records\"],\n        num_dims=config[\"num_input_dims\"],\n        input_min=config[\"input_min\"],\n        input_max=config[\"input_max\"])\n\n    if isinstance(raw_training_inputs, tuple):\n      # This means that raw inputs are 2-d mesh grid. Convert them into list of\n      # 2-d points.\n      training_inputs = list(np.dstack(raw_training_inputs).reshape((-1, 2)))\n    else:\n      training_inputs = raw_training_inputs\n\n    training_labels = [config[\"y_function\"](x) for x in training_inputs]\n    return training_inputs, training_labels\n\n  def _TrainModel(self, config):\n    \"\"\"Trains model and returns loss.\n\n    Args:\n      config: Layer config internal for this test which specifies params of\n        linear layer to train.\n\n    Returns:\n      Training loss.\n    \"\"\"\n    logging.info(\"Testing config:\")\n    logging.info(config)\n    config = self._SetDefaults(config)\n\n    self._ResetAllBackends()\n\n    training_inputs, training_labels = (\n        self._GetTrainingInputsAndLabels(config))\n    units = config[\"units\"]\n    num_input_dims = config[\"num_input_dims\"]\n    if units > 1:\n      # In order to test multi 'units' linear, replicate inputs 'units' times\n      # and later use just one out of 'units' outputs in order to ensure that\n      # multi 'units' linear trains exactly similar to single 'units' one.\n      training_inputs = [\n          np.tile(np.expand_dims(x, axis=0), reps=[units, 1])\n          for x in training_inputs\n      ]\n      input_shape = (units, num_input_dims)\n    else:\n      input_shape = (num_input_dims,)\n\n    linear_layer = linl.Linear(\n        input_shape=input_shape,\n        num_input_dims=config[\"num_input_dims\"],\n        units=units,\n        monotonicities=config[\"monotonicities\"],\n        monotonic_dominances=config[\"monotonic_dominances\"],\n        range_dominances=config[\"range_dominances\"],\n        input_min=config[\"clip_min\"],\n        input_max=config[\"clip_max\"],\n        
use_bias=config[\"use_bias\"],\n        normalization_order=config[\"normalization_order\"],\n        kernel_initializer=keras.initializers.Constant(\n            config[\"kernel_init_constant\"]),\n        bias_initializer=keras.initializers.Constant(\n            config[\"bias_init_constant\"]),\n        kernel_regularizer=config[\"kernel_regularizer\"],\n        bias_regularizer=config[\"bias_regularizer\"],\n        dtype=tf.float32)\n    model = keras.models.Sequential()\n    model.add(linear_layer)\n    # When we use multi-unit linear, we only extract a single unit for testing.\n    if units > 1:\n      unit_index = config[\"unit_index\"]\n      model.add(\n          keras.layers.Lambda(lambda x: x[:, unit_index:unit_index + 1]))\n    optimizer = config[\"optimizer\"](learning_rate=config[\"learning_rate\"])\n    model.compile(loss=keras.losses.mean_squared_error, optimizer=optimizer)\n\n    training_data = (training_inputs, training_labels)\n\n    loss = test_utils.run_training_loop(\n        config=config, training_data=training_data, keras_model=model\n    )\n\n    assetion_ops = linear_layer.assert_constraints(\n        eps=config[\"allowed_constraints_violation\"])\n    if not tf.executing_eagerly() and assetion_ops:\n      tf.compat.v1.keras.backend.get_session().run(assetion_ops)\n    return loss\n\n  def _NegateAndTrain(self, config):\n    \"\"\"Changes monotonicity directions to opposite and trains model.\"\"\"\n    negated_config = dict(config)\n    negated_config[\"y_function\"] = lambda x: -config[\"y_function\"](x)\n    negated_config[\"bias_init_constant\"] = -config[\"bias_init_constant\"]\n    negated_config[\"kernel_init_constant\"] = -config[\"kernel_init_constant\"]\n\n    if isinstance(config[\"monotonicities\"], list):\n      negated_config[\"monotonicities\"] = [\n          -monotonicity for monotonicity in\n          utils.canonicalize_monotonicities(config[\"monotonicities\"])\n      ]\n    else:\n      negated_config[\"monotonicities\"] = -config[\"monotonicities\"]\n\n    negated_loss = self._TrainModel(negated_config)\n    return negated_loss\n\n  @parameterized.parameters((False, 1.623906), (True, 0.456815))\n  # Expected losses are computed by running this test. Correctness is verified\n  # manually by looking at visualisation of learned function vs ground truth.\n  def testOneDUnconstrained(self, use_bias, expected_loss):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 1,\n        \"use_bias\": use_bias,\n        \"num_training_records\": 128,\n        \"num_training_epoch\": 400,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._ScaterXUniformly,\n        \"input_min\": 5.0,\n        \"input_max\": 25.0,\n        \"y_function\": self._SinPlusXPlusD,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS)\n\n  @parameterized.parameters((False, 0.881774), (True, 0.441771))\n  # Expected losses are computed by running this test. 
Correctness is verified\n  # manually by looking at visualisation of learned function vs ground truth.\n  def testTwoDUnconstrained(self, use_bias, expected_loss):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 2,\n        \"use_bias\": use_bias,\n        \"num_training_records\": 64,\n        \"num_training_epoch\": 160,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"input_min\": 0.0,\n        \"input_max\": 4.0,\n        \"y_function\": self._GenLinearFunction(\n            weights=[-1.0, 2.0],\n            bias=-2.0,\n            noise=lambda x: math.sin(sum(x)) / 1.0),\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS)\n\n  def testInitializers(self):\n    # Test initializers by trying to fit linear function using 0 iterations.\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 2,\n        \"use_bias\": True,\n        \"num_training_records\": 64,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"input_min\": 0.0,\n        \"input_max\": 4.0,\n        \"kernel_init_constant\": 3.0,\n        \"bias_init_constant\": -2.0,\n        \"y_function\": self._GenLinearFunction(weights=[3.0, 3.0], bias=-2.0)\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.0, delta=_LOSS_EPS)\n\n  def testAssertConstraints(self):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 4,\n        \"use_bias\": True,\n        \"num_training_records\": 64,\n        \"num_training_epoch\": 0,\n        \"normalization_order\": 1,\n        \"monotonicities\": [1] * 4,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._ScaterXUniformly,\n        \"input_min\": 0.0,\n        \"input_max\": 4.0,\n        \"kernel_init_constant\": 0.25,\n        \"bias_init_constant\": -2.0,\n        \"y_function\": self._GenLinearFunction(weights=[0.25] * 4, bias=-2.0)\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.0, delta=_LOSS_EPS)\n\n    with self.assertRaises(tf.errors.InvalidArgumentError):\n      config[\"normalization_order\"] = 2\n      self._TrainModel(config)\n\n    with self.assertRaises(tf.errors.InvalidArgumentError):\n      # Setting valid normalization order back and instead violating\n      # monotonicity.\n      config[\"normalization_order\"] = 1\n      config[\"monotonicities\"] = [1, 1, -1, 0]\n      self._TrainModel(config)\n\n  @parameterized.parameters((False, 1.623906), (True, 0.456815))\n  # Expected losses are computed by running this test. 
Correctness is verified\n  # manually by looking at visualisation of learned function vs ground truth.\n  def testOneDMonotonicities_MonotonicInput(self, use_bias, expected_loss):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 1,\n        \"monotonicities\": [1],\n        \"use_bias\": use_bias,\n        \"num_training_records\": 128,\n        \"num_training_epoch\": 400,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._ScaterXUniformly,\n        \"input_min\": 5.0,\n        \"input_max\": 25.0,\n        \"y_function\": self._SinPlusXPlusD,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS)\n    self.assertAlmostEqual(loss, self._NegateAndTrain(config), delta=_SMALL_EPS)\n\n  @parameterized.parameters((False, 62.670425), (True, 3.326165))\n  # Expected losses are computed by running this test. Correctness is verified\n  # manually by looking at visualisation of learned function vs ground truth.\n  def testOneDMonotonicities_AntiMonotonicInput(self, use_bias, expected_loss):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 1,\n        \"monotonicities\": [\"increasing\"],\n        \"use_bias\": use_bias,\n        \"num_training_records\": 128,\n        \"num_training_epoch\": 400,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._ScaterXUniformly,\n        \"input_min\": 5.0,\n        \"input_max\": 25.0,\n        \"y_function\": lambda x: -self._SinPlusXPlusD(x),\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS)\n    self.assertAlmostEqual(loss, self._NegateAndTrain(config), delta=_SMALL_EPS)\n\n  @parameterized.parameters((1, 2.0), (1, -2.0), (2, 2.0), (2, -2.0))\n  # Expected loss is computed by running this test. 
Correctness is verified\n  # manually by looking at visualisation of learned function vs ground truth.\n  def testOneDNormalizationOrder(self, norm_order, weight):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 1,\n        \"monotonicities\": [0],\n        \"normalization_order\": norm_order,\n        \"use_bias\": True,\n        \"num_training_records\": 128,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._ScaterXUniformly,\n        \"input_min\": 0.0,\n        \"input_max\": 5.0,\n        \"y_function\": self._GenLinearFunction(weights=[weight], bias=0.0)\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    # For 1-d case normalization order does not change anything.\n    self.assertAlmostEqual(loss, 1.727717, delta=_LOSS_EPS)\n\n  def testOneDNormalizationOrderZeroWeights(self):\n    if _DISABLE_ALL:\n      return\n    # Normalization is impossible when all weights are 0.0 so weights should not\n    # be affected by it.\n    config = {\n        \"num_input_dims\": 1,\n        \"monotonicities\": [\"none\"],\n        \"normalization_order\": 1,\n        \"use_bias\": True,\n        \"num_training_records\": 128,\n        \"num_training_epoch\": 20,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._ScaterXUniformly,\n        \"input_min\": 0.0,\n        \"input_max\": 5.0,\n        \"y_function\": self._GenLinearFunction(weights=[0.0], bias=0.0)\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.0, delta=_LOSS_EPS)\n\n  @parameterized.parameters(\n      (0.441771, 0),\n      (0.441771, [\"none\", \"none\"]),\n      (2.61706, 1),\n      (2.61706, [\"increasing\", \"increasing\"]),\n      (2.61706, [\"increasing\", \"none\"]),\n      (0.441771, [\"none\", \"increasing\"])\n  )\n  # Expected losses are computed by running this test. 
Correctness is verified\n  # manually by looking at visualisation of learned function vs ground truth.\n  def testTwoDMonotonicity(self, expected_loss, monotonicities):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 2,\n        \"monotonicities\": monotonicities,\n        \"use_bias\": True,\n        \"num_training_records\": 64,\n        \"num_training_epoch\": 160,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"input_min\": 0.0,\n        \"input_max\": 4.0,\n        \"y_function\": self._GenLinearFunction(\n            weights=[-1.0, 2.0],\n            bias=-2.0,\n            noise=lambda x: math.sin(sum(x)) / 1.0)\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS)\n    self.assertAlmostEqual(loss, self._NegateAndTrain(config), delta=_SMALL_EPS)\n\n    multioutput_config = dict(config)\n    units = 3\n    multioutput_config[\"units\"] = units\n    for unit_index in range(units):\n      multioutput_config[\"unit_index\"] = unit_index\n      loss = self._TrainModel(multioutput_config)\n      self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS)\n      self.assertAlmostEqual(\n          loss, self._NegateAndTrain(multioutput_config), delta=_SMALL_EPS)\n\n  @parameterized.parameters(\n      (1, [0.2, 0.3], 0, 0.250532),  # Testing sum of weights < 1.0.\n      (1, [0.2, 0.3], 1, 0.250532),  # Monotonicity does not matter here.\n      (2, [0.2, 0.3], 0, 0.753999),\n      (1, [1.0, 2.0], 0, 5.688659),  # Testing sum of weights > 1.0.\n      (1, [-1.0, 2.0], 0, 4.043515),\n      # With negative weights monotonicity matters.\n      (1, [-1.0, 2.0], 1, 3.433537))\n  # Expected losses are computed by running this test. Correctness is verified\n  # manually by looking at visualisation of learned function vs ground truth.\n  def testTwoDNormalizationOrder(self, norm_order, weights, monotonicities,\n                                 expected_loss):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 2,\n        \"normalization_order\": norm_order,\n        \"monotonicities\": monotonicities,\n        # If normalization order is set then layer will always converges to\n        # extremes if there is no bias or other layers. That's why we always\n        # use bias for normalization order tests.\n        \"use_bias\": True,\n        \"num_training_records\": 64,\n        \"num_training_epoch\": 160,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"input_min\": 0.0,\n        \"input_max\": 4.0,\n        \"y_function\": self._GenLinearFunction(\n            weights=weights, noise=lambda x: math.sin(sum(x)) / 10.0)\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS)\n\n  @parameterized.parameters(\n      ([0.5, 0.6, 0.06, 0.07, 0.08], [1, 1, 1, 1, 1], 0.0408642),\n      ([0.5, -0.6, 0.06, -0.07, 0.08], [1, 1, 1, 1, 1], 0.561592),\n      ([0.5, -0.6, 0.06, -0.07, 0.08], [0, 0, 1, 1, 1], 0.047663))\n  # Expected losses are computed by running this test. 
Correctness is verified\n  # manually by looking at visualisation of learned function vs ground truth.\n  def testFiveDAllConstraints(self, weights, monotonicities, expected_loss):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 5,\n        \"normalization_order\": 1,\n        \"monotonicities\": monotonicities,\n        \"use_bias\": True,\n        \"num_training_records\": 640,\n        \"num_training_epoch\": 160,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._ScaterXUniformly,\n        \"input_min\": 0.0,\n        \"kernel_init_constant\": 0.7,\n        \"input_max\": 4.0,\n        \"y_function\": self._GenLinearFunction(\n            weights=weights, noise=lambda x: math.sin(sum(x)) / 30.0)\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS)\n\n  @parameterized.parameters((0.85766, [(0, 1)]),\n                            (1e-13, [(1, 0)]))\n  # Expected losses are computed by running this test. Correctness is verified\n  # manually by looking at visualisation of learned function vs ground truth.\n  def testTwoDMonotonicDominance(self, expected_loss, dominances):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 2,\n        \"monotonicities\": [\"increasing\", \"increasing\"],\n        \"monotonic_dominances\": dominances,\n        \"num_training_records\": 64,\n        \"num_training_epoch\": 160,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"input_min\": 0.0,\n        \"input_max\": 4.0,\n        \"y_function\": self._GenLinearFunction(weights=[1.0, 2.0])\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS)\n\n  @parameterized.parameters(([(0, 1)], [1, 1, 0], [1.0, 2.0, 3.0], 1.8409),\n                            ([(0, 1)], [-1, -1, 0], [-1.0, -2.0, -3.0], 1.8409),\n                            ([(1, 0)], [1, 1, 0], [1.0, 2.0, 3.0], 0.6567),\n                            ([(1, 0)], [-1, -1, 0], [-1.0, -2.0, -3.0], 0.6567))\n  # Expected losses are computed by running this test. 
Correctness is verified\n  # manually by looking at visualisation of learned function vs ground truth.\n  def testTwoDRangeDominance(self, dominances, monotonicities, weights,\n                             expected_loss):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 3,\n        \"monotonicities\": monotonicities,\n        \"range_dominances\": dominances,\n        \"clip_min\": [0.0, 0.0, \"none\"],\n        \"clip_max\": (1.0, 4.0, \"none\"),\n        \"num_training_records\": 64,\n        \"num_training_epoch\": 160,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._ScaterXUniformly,\n        \"input_min\": 0.0,\n        \"input_max\": 4.0,\n        \"y_function\": self._GenLinearFunction(weights=weights)\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS)\n\n  @parameterized.parameters(\n      # Standard Keras regularizer:\n      (keras.regularizers.l1_l2(l1=0.01, l2=0.001),),\n      # Tuple of regularizers:\n      ((keras.regularizers.l1_l2(l1=0.01, l2=0.0),\n        keras.regularizers.l1_l2(l1=0.0, l2=0.001)),),\n  )\n  def testRegularizers(self, regularizer):\n    if _DISABLE_ALL:\n      return\n    config = {\n        \"num_input_dims\": 2,\n        \"use_bias\": True,\n        \"num_training_records\": 64,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.5,\n        \"x_generator\": self._TwoDMeshGrid,\n        \"input_min\": 0.0,\n        \"input_max\": 4.0,\n        \"kernel_init_constant\": 2.0,\n        \"bias_init_constant\": 3.0,\n        \"y_function\": self._GenLinearFunction(weights=[2.0, 2.0], bias=3.0),\n        \"kernel_regularizer\": regularizer,\n        \"bias_regularizer\": regularizer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    # This loss is pure regularization loss because initializer matches target\n    # function and there was 0 training epochs.\n    self.assertAlmostEqual(loss, 0.087, delta=_LOSS_EPS)\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_lattice/python/model_info.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Classes defining trained TFL model structure and parameter information.\n\nThis package provides representations and tools for analysis of a trained\nTF Lattice model, e.g. a canned estimator in saved model format.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\n\n\nclass ModelGraph(\n    collections.namedtuple('ModelGraph', ['nodes', 'output_node'])):\n  \"\"\"Model info and parameter as a graph.\n\n  Note that this is not a TF graph, but rather a graph of python object that\n  describe model structure and parameters.\n\n  Attributes:\n    nodes: List of all the nodes in the model.\n    output_node: The output node of the model.\n  \"\"\"\n\n\nclass InputFeatureNode(\n    collections.namedtuple('InputFeatureNode',\n                           ['name', 'is_categorical', 'vocabulary_list'])):\n  \"\"\"Input features to the model.\n\n  Attributes:\n    name: Name of the input feature.\n    is_categorical: If the feature is categorical.\n    vocabulary_list: Category values for categorical features or None.\n  \"\"\"\n\n\nclass PWLCalibrationNode(\n    collections.namedtuple('PWLCalibrationNode', [\n        'input_node', 'input_keypoints', 'output_keypoints', 'default_input',\n        'default_output'\n    ])):\n  \"\"\"Represetns a PWL calibration layer.\n\n  Attributes:\n    input_node: Input node for the calibration.\n    input_keypoints: Input keypoints for PWL calibration.\n    output_keypoints: Output keypoints for PWL calibration.\n    default_input: Default/missing input value or None.\n    default_output: Default/missing output value or None.\n  \"\"\"\n\n\nclass CategoricalCalibrationNode(\n    collections.namedtuple('CategoricalCalibrationNode',\n                           ['input_node', 'output_values', 'default_input'])):\n  \"\"\"Represetns a categorical calibration layer.\n\n  Attributes:\n    input_node: Input node for the calibration.\n    output_values: Output calibration values. 
If the calibrated feature has\n      default/missing values, the last value will be for default/missing.\n    default_input: Default/missing input value or None.\n  \"\"\"\n\n\nclass LinearNode(\n    collections.namedtuple('LinearNode',\n                           ['input_nodes', 'coefficients', 'bias'])):\n  \"\"\"Represents a linear layer.\n\n  Attributes:\n    input_nodes: List of input nodes to the linear layer.\n    coefficients: Linear weights.\n    bias: Bias term for the linear layer.\n  \"\"\"\n\n\nclass LatticeNode(\n    collections.namedtuple('LatticeNode', ['input_nodes', 'weights'])):\n  \"\"\"Represents a lattice layer.\n\n  Attributes:\n    input_nodes: List of input nodes to the lattice layer.\n    weights: Lattice parameters.\n  \"\"\"\n\n\nclass KroneckerFactoredLatticeNode(\n    collections.namedtuple('KroneckerFactoredLatticeNode',\n                           ['input_nodes', 'weights', 'scale', 'bias'])):\n  \"\"\"Represents a kronecker-factored lattice layer.\n\n  Attributes:\n    input_nodes: List of input nodes to the kronecker-factored lattice layer.\n    weights: Kronecker-factored lattice kernel parameters of shape\n      `(1, lattice_sizes, units * dims, num_terms)`.\n    scale: Kronecker-factored lattice scale parameters of shape\n      `(units, num_terms)`.\n    bias: Kronecker-factored lattice bias parameters of shape `(units)`.\n  \"\"\"\n\n\nclass MeanNode(collections.namedtuple('MeanNode', ['input_nodes'])):\n  \"\"\"Represents an averaging layer.\n\n  Attributes:\n    input_nodes: List of input nodes to the average layer.\n  \"\"\"\n"
  },
  {
    "path": "tensorflow_lattice/python/parallel_combination_layer.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"ParallelCombination layer for combining several parallel calibration layers.\n\nThis layer wraps several calibration layers under single ParallelCombination one\nthat can be used by `Sequential` Keras model.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow_lattice.python import categorical_calibration_layer\nfrom tensorflow_lattice.python import lattice_layer\nfrom tensorflow_lattice.python import linear_layer\nfrom tensorflow_lattice.python import pwl_calibration_layer\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n\n# TODO: Add support for calibrators with units > 1.\nclass ParallelCombination(keras.layers.Layer):\n  # pyformat: disable\n  \"\"\"Wraps several parallel calibration layers under single one.\n\n  `ParallelCombination` is designed for combning several calibration layers\n  which output goes into single `Lattice` or `Linear` layer in order to be able\n  to use calibration layers within `Sequential` model.\n\n  Difference from `keras.layers.Concatenate` is that last one operates on\n  already built objects and thus cannot be used to group layers for `Sequential`\n  model.\n\n  Input shape:\n    `(batch_size, k)` or list of length `k` of shapes: `(batch_size, 1)` where\n    `k` is a number of associated calibration layers.\n\n  Output shape:\n    `(batch_size, k)` or list of length `k` of shapes: `(batch_size, 1)` where\n    `k` is a number of associated calibration layers. Shape of output depends on\n    `single_output` parameter.\n\n  Attributes:\n    - All `__init__` arguments.\n\n  Example:\n\n  Example usage with a Sequential model:\n\n  ```python\n  model = keras.models.Sequential()\n  combined_calibrators = ParallelCombination()\n  for i in range(num_dims):\n    calibration_layer = PWLCalibration(...)\n    combined_calibrators.append(calibration_layer)\n  model.add(combined_calibrators)\n  model.add(Lattice(...))\n  ```\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self, calibration_layers=None, single_output=True, **kwargs):\n    \"\"\"Initializes an instance of `ParallelCombination`.\n\n    Args:\n      calibration_layers: List of `PWLCalibration` or `CategoricalCalibration`\n        objects or any other layers taking and returning tensor of shape\n        `(batch_size, 1)`.\n      single_output: if True returns output as single tensor of shape\n        `(batch_size, k)`. 
Otherwise returns list of `k` tensors of shape\n        `(batch_size, 1)`.\n      **kwargs: other args passed to `keras.layers.Layer` initializer.\n    \"\"\"\n    super(ParallelCombination, self).__init__(**kwargs)\n    self.calibration_layers = []\n    for calibration_layer in calibration_layers or []:\n      if not isinstance(calibration_layer, dict):\n        self.calibration_layers.append(calibration_layer)\n      else:\n        # Keras deserialization logic must have explicit acceess to all custom\n        # classes. This is standard way to provide such access.\n        with keras.utils.custom_object_scope({\n            \"Lattice\":\n                lattice_layer.Lattice,\n            \"Linear\":\n                linear_layer.Linear,\n            \"PWLCalibration\":\n                pwl_calibration_layer.PWLCalibration,\n            \"CategoricalCalibration\":\n                categorical_calibration_layer.CategoricalCalibration,\n        }):\n          self.calibration_layers.append(\n              keras.layers.deserialize(\n                  calibration_layer, use_legacy_format=True\n              )\n          )\n    self.single_output = single_output\n\n  def append(self, calibration_layer):\n    \"\"\"Appends new calibration layer to the end.\"\"\"\n    self.calibration_layers.append(calibration_layer)\n\n  def build(self, input_shape):\n    \"\"\"Standard Keras build() method.\"\"\"\n    if isinstance(input_shape, list):\n      if len(input_shape) != len(self.calibration_layers):\n        raise ValueError(\"Number of ParallelCombination input tensors does not \"\n                         \"match number of calibration layers. input_shape: %s, \"\n                         \"layers: %s\" % (input_shape, self.calibration_layers))\n    else:\n      if input_shape[1] != len(self.calibration_layers):\n        raise ValueError(\"Second dimension of ParallelCombination input tensor \"\n                         \"does not match number of calibration layers. \"\n                         \"input_shape: %s, layers: %s\" %\n                         (input_shape, self.calibration_layers))\n    super(ParallelCombination, self).build(input_shape)\n\n  def call(self, inputs):\n    \"\"\"Standard Keras call() method.\"\"\"\n    if not isinstance(inputs, list):\n      if len(inputs.shape) != 2:\n        raise ValueError(\"'inputs' is expected to have rank-2. \"\n                         \"Given: %s\" % inputs)\n      inputs = tf.split(inputs, axis=1, num_or_size_splits=inputs.shape[1])\n    if len(inputs) != len(self.calibration_layers):\n      raise ValueError(\"Number of ParallelCombination input tensors does not \"\n                       \"match number of calibration layers. 
inputs: %s, \"\n                       \"layers: %s\" % (inputs, self.calibration_layers))\n    outputs = [\n        layer(one_d_input)\n        for layer, one_d_input in zip(self.calibration_layers, inputs)\n    ]\n    if self.single_output:\n      return tf.concat(outputs, axis=1)\n    else:\n      return outputs\n\n  def compute_output_shape(self, input_shape):\n    if self.single_output:\n      return tf.TensorShape([None, len(self.calibration_layers)])\n    else:\n      return [tf.TensorShape([None, 1])] * len(self.calibration_layers)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    config = {\n        \"calibration_layers\": [\n            keras.layers.serialize(layer, use_legacy_format=True)\n            for layer in self.calibration_layers\n        ],\n        \"single_output\": self.single_output,\n    }  # pyformat: disable\n    config.update(super(ParallelCombination, self).get_config())\n    return config\n"
  },
  {
    "path": "tensorflow_lattice/python/parallel_combination_test.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for Lattice Layer.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tempfile\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_lattice.python import lattice_layer as ll\nfrom tensorflow_lattice.python import parallel_combination_layer as pcl\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n\nclass ParallelCombinationTest(parameterized.TestCase, tf.test.TestCase):\n\n  def setUp(self):\n    super(ParallelCombinationTest, self).setUp()\n    self.disable_all = False\n    keras.utils.set_random_seed(42)\n\n  def testParallelCombinationSingleInput(self):\n    if self.disable_all:\n      return\n    all_calibrators = pcl.ParallelCombination()\n    for i in range(3):\n      # Its not typical to use 1-d Lattice layer for calibration, but lets do it\n      # to avoid redundant dependency on PWLCalibration layer.\n      calibrator = ll.Lattice(\n          lattice_sizes=[2], output_min=0.0, output_max=i + 1.0)\n      all_calibrators.append(calibrator)\n\n    # Given output range specified below linear initializer will have lattice to\n    # simply sum up inputs.\n    simple_sum = ll.Lattice(\n        lattice_sizes=[5] * 3,\n        kernel_initializer=\"linear_initializer\",\n        output_min=0.0,\n        output_max=12.0,\n        name=\"SummingLattice\")\n    model = keras.models.Sequential()\n    model.add(all_calibrators)\n    model.add(simple_sum)\n\n    test_inputs = np.asarray([\n        [0.0, 0.0, 0.0],\n        [0.1, 0.2, 0.3],\n        [1.0, 1.0, 1.0],\n    ])\n    predictions = model.predict(test_inputs)\n    print(\"predictions\")\n    print(predictions)\n    self.assertTrue(np.allclose(predictions, np.asarray([[0.0], [1.4], [6.0]])))\n\n  def testParallelCombinationMultipleInputs(self):\n    if self.disable_all:\n      return\n    input_layers = [keras.layers.Input(shape=[1]) for _ in range(3)]\n    all_calibrators = pcl.ParallelCombination(single_output=False)\n    for i in range(3):\n      # Its not typical to use 1-d Lattice layer for calibration, but lets do it\n      # to avoid redundant dependency on PWLCalibration layer.\n      calibrator = ll.Lattice(\n          lattice_sizes=[2], output_min=0.0, output_max=i + 1.0)\n      all_calibrators.append(calibrator)\n\n    # Given output range specified below linear initializer will have lattice to\n    # simply sum up inputs.\n    simple_sum = ll.Lattice(\n        lattice_sizes=[5] * 3,\n        kernel_initializer=\"linear_initializer\",\n        output_min=0.0,\n        output_max=12.0,\n        name=\"SummingLattice\",\n        trainable=False)\n\n    output = simple_sum(all_calibrators(input_layers))\n    model = 
keras.models.Model(inputs=input_layers, outputs=output)\n\n    test_inputs = [\n        np.asarray([[0.0], [0.1], [1.0]]),\n        np.asarray([[0.0], [0.2], [1.0]]),\n        np.asarray([[0.0], [0.3], [1.0]]),\n    ]\n    predictions = model.predict(test_inputs)\n    print(\"predictions\")\n    print(predictions)\n    self.assertTrue(np.allclose(predictions, np.asarray([[0.0], [1.4], [6.0]])))\n\n  def testParallelCombinationClone(self):\n    if self.disable_all:\n      return\n    input_layers = [keras.layers.Input(shape=[1]) for _ in range(3)]\n    all_calibrators = pcl.ParallelCombination(single_output=False)\n    for i in range(3):\n      # Its not typical to use 1-d Lattice layer for calibration, but lets do it\n      # to avoid redundant dependency on PWLCalibration layer.\n      calibrator = ll.Lattice(\n          lattice_sizes=[2], output_min=0.0, output_max=i + 1.0)\n      all_calibrators.append(calibrator)\n\n    # Given output range specified below linear initializer will have lattice to\n    # simply sum up inputs.\n    simple_sum = ll.Lattice(\n        lattice_sizes=[5] * 3,\n        kernel_initializer=\"linear_initializer\",\n        output_min=0.0,\n        output_max=12.0,\n        name=\"SummingLattice\",\n        trainable=False)\n\n    output = simple_sum(all_calibrators(input_layers))\n    model = keras.models.Model(inputs=input_layers, outputs=output)\n    clone = keras.models.clone_model(model)\n\n    test_inputs = [\n        np.asarray([[0.0], [0.1], [1.0]]),\n        np.asarray([[0.0], [0.2], [1.0]]),\n        np.asarray([[0.0], [0.3], [1.0]]),\n    ]\n    predictions = clone.predict(test_inputs)\n    print(\"predictions\")\n    print(predictions)\n    self.assertTrue(np.allclose(predictions, np.asarray([[0.0], [1.4], [6.0]])))\n\n    with tempfile.NamedTemporaryFile(suffix=\".h5\") as f:\n      model.save(f.name)\n      loaded_model = keras.models.load_model(\n          f.name,\n          custom_objects={\n              \"ParallelCombination\": pcl.ParallelCombination,\n              \"Lattice\": ll.Lattice,\n          },\n      )\n      predictions = loaded_model.predict(test_inputs)\n      self.assertTrue(\n          np.allclose(predictions, np.asarray([[0.0], [1.4], [6.0]])))\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_lattice/python/premade.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"TF Lattice premade models implement typical monotonic model architectures.\n\nYou can use TFL premade models to easily construct commonly used monotonic model\narchitectures. To construct a TFL premade model, construct a model configuration\nfrom `tfl.configs` and pass it to the premade model constructor. No fields in\nthe model config will be automatically filled in, so the config must be fully\nspecified. Note that the inputs to the model should match the order in which\nthey are defined in the feature configs.\n\n```python\nmodel_config = tfl.configs.CalibratedLatticeConfig(...)\ncalibrated_lattice_model = tfl.premade.CalibratedLattice(\n    model_config=model_config)\ncalibrated_lattice_model.compile(...)\ncalibrated_lattice_model.fit(...)\n```\n\nSupported models are defined in `tfl.configs`. Each model architecture can be\nused the same as any other `keras.Model`.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\nfrom . import aggregation_layer\nfrom . import categorical_calibration_layer\nfrom . import configs\nfrom . import kronecker_factored_lattice_layer as kfll\nfrom . import lattice_layer\nfrom . import linear_layer\nfrom . import parallel_combination_layer\nfrom . import premade_lib\nfrom . import pwl_calibration_layer\nfrom . import rtl_layer\n\n\n# TODO: add support for serialization and object scoping or annoations.\nclass CalibratedLatticeEnsemble(keras.Model):\n  \"\"\"Premade model for Tensorflow calibrated lattice ensemble models.\n\n  Creates a `keras.Model` for the model architecture specified by the\n  `model_config`, which should be a\n  `tfl.configs.CalibratedLatticeEnsembleConfig`. No fields in the model config\n  will be automatically filled in, so the config must be fully specified. 
Note\n  that the inputs to the model should match the order in which they are defined\n  in the feature configs.\n\n  Example:\n\n  ```python\n  model_config = tfl.configs.CalibratedLatticeEnsembleConfig(...)\n  calibrated_lattice_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n      model_config=model_config)\n  calibrated_lattice_ensemble_model.compile(...)\n  calibrated_lattice_ensemble_model.fit(...)\n  ```\n\n  Attributes:\n    model_config: Model configuration object describing model architecture.\n      Should be a `tfl.configs.CalibratedLatticeEnsembleConfig` instance.\n  \"\"\"\n\n  def __init__(self, model_config=None, dtype=tf.float32, **kwargs):\n    \"\"\"Initializes a `CalibratedLatticeEnsemble` instance.\n\n    Args:\n      model_config: Model configuration object describing model architecutre.\n        Should be one of the model configs in `tfl.configs`.\n      dtype: dtype of layers used in the model.\n      **kwargs: Any additional `keras.Model` arguments\n    \"\"\"\n    # Set our model_config\n    self.model_config = model_config\n    # Check if we are constructing with already provided inputs/outputs, e.g.\n    # when we are loading a model.\n    if 'inputs' in kwargs and 'outputs' in kwargs:\n      super(CalibratedLatticeEnsemble, self).__init__(**kwargs)\n      return\n    if model_config is None:\n      raise ValueError('Must provide a model_config.')\n    # Check that proper config has been given.\n    if not isinstance(model_config, configs.CalibratedLatticeEnsembleConfig):\n      raise ValueError('Invalid config type: {}'.format(type(model_config)))\n    # Verify that the config is fully specified.\n    premade_lib.verify_config(model_config)\n    # Get feature configs and construct model.\n    input_layer = premade_lib.build_input_layer(\n        feature_configs=model_config.feature_configs, dtype=dtype)\n\n    lattice_outputs = premade_lib.build_calibrated_lattice_ensemble_layer(\n        calibration_input_layer=input_layer,\n        model_config=model_config,\n        average_outputs=(not model_config.use_linear_combination),\n        dtype=dtype)\n\n    if model_config.use_linear_combination:\n      averaged_lattice_output = premade_lib.build_linear_combination_layer(\n          ensemble_outputs=lattice_outputs,\n          model_config=model_config,\n          dtype=dtype)\n    else:\n      averaged_lattice_output = lattice_outputs\n\n    if model_config.output_calibration:\n      model_output = premade_lib.build_output_calibration_layer(\n          output_calibration_input=averaged_lattice_output,\n          model_config=model_config,\n          dtype=dtype)\n    else:\n      model_output = averaged_lattice_output\n\n    # Define inputs and initialize model.\n    inputs = [\n        input_layer[feature_config.name]\n        for feature_config in model_config.feature_configs\n    ]\n    kwargs['inputs'] = inputs\n    kwargs['outputs'] = model_output\n    super(CalibratedLatticeEnsemble, self).__init__(**kwargs)\n\n  def get_config(self):\n    \"\"\"Returns a configuration dictionary.\"\"\"\n    config = {'name': self.name, 'trainable': self.trainable}\n    config['model_config'] = keras.utils.legacy.serialize_keras_object(\n        self.model_config\n    )\n    return config\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    model_config = keras.utils.legacy.deserialize_keras_object(\n        config.get('model_config'), custom_objects=custom_objects\n    )\n    premade_lib.verify_config(model_config)\n    return 
cls(model_config,\n               name=config.get('name', None),\n               trainable=config.get('trainable', True))\n\n\nclass CalibratedLattice(keras.Model):\n  \"\"\"Premade model for Tensorflow calibrated lattice models.\n\n  Creates a `keras.Model` for the model architecture specified by the\n  `model_config`, which should be a `tfl.configs.CalibratedLatticeConfig`. No\n  fields in the model config will be automatically filled in, so the config\n  must be fully specified. Note that the inputs to the model should match the\n  order in which they are defined in the feature configs.\n\n  Example:\n\n  ```python\n  model_config = tfl.configs.CalibratedLatticeConfig(...)\n  calibrated_lattice_model = tfl.premade.CalibratedLattice(\n      model_config=model_config)\n  calibrated_lattice_model.compile(...)\n  calibrated_lattice_model.fit(...)\n  ```\n\n  Attributes:\n    model_config: Model configuration object describing model architecture.\n      Should be a `tfl.configs.CalibratedLatticeConfig` instance.\n  \"\"\"\n\n  def __init__(self, model_config=None, dtype=tf.float32, **kwargs):\n    \"\"\"Initializes a `CalibratedLattice` instance.\n\n    Args:\n      model_config: Model configuration object describing model architecture.\n        Should be one of the model configs in `tfl.configs`.\n      dtype: dtype of layers used in the model.\n      **kwargs: Any additional `keras.Model` arguments.\n    \"\"\"\n    # Set our model_config\n    self.model_config = model_config\n    # Check if we are constructing with already provided inputs/outputs, e.g.\n    # when we are loading a model.\n    if 'inputs' in kwargs and 'outputs' in kwargs:\n      super(CalibratedLattice, self).__init__(**kwargs)\n      return\n    if model_config is None:\n      raise ValueError('Must provide a model_config.')\n    # Check that proper config has been given.\n    if not isinstance(model_config, configs.CalibratedLatticeConfig):\n      raise ValueError('Invalid config type: {}'.format(type(model_config)))\n    # Verify that the config is fully specified.\n    premade_lib.verify_config(model_config)\n    # Get feature configs and construct model.\n    input_layer = premade_lib.build_input_layer(\n        feature_configs=model_config.feature_configs, dtype=dtype)\n    submodels_inputs = premade_lib.build_calibration_layers(\n        calibration_input_layer=input_layer,\n        model_config=model_config,\n        layer_output_range=premade_lib.LayerOutputRange.INPUT_TO_LATTICE,\n        submodels=[[\n            feature_config.name\n            for feature_config in model_config.feature_configs\n        ]],\n        separate_calibrators=False,\n        dtype=dtype)\n\n    lattice_layer_output_range = (\n        premade_lib.LayerOutputRange.INPUT_TO_FINAL_CALIBRATION\n        if model_config.output_calibration else\n        premade_lib.LayerOutputRange.MODEL_OUTPUT)\n    lattice_output = premade_lib.build_lattice_layer(\n        lattice_input=submodels_inputs[0],\n        feature_configs=model_config.feature_configs,\n        model_config=model_config,\n        layer_output_range=lattice_layer_output_range,\n        submodel_index=0,\n        is_inside_ensemble=False,\n        dtype=dtype)\n\n    if model_config.output_calibration:\n      model_output = premade_lib.build_output_calibration_layer(\n          output_calibration_input=lattice_output,\n          model_config=model_config,\n          dtype=dtype)\n    else:\n      model_output = lattice_output\n\n    # Define inputs and initialize model.\n    inputs 
= [\n        input_layer[feature_config.name]\n        for feature_config in model_config.feature_configs\n    ]\n    kwargs['inputs'] = inputs\n    kwargs['outputs'] = model_output\n    super(CalibratedLattice, self).__init__(**kwargs)\n\n  def get_config(self):\n    \"\"\"Returns a configuration dictionary.\"\"\"\n    config = {'name': self.name, 'trainable': self.trainable}\n    config['model_config'] = keras.utils.legacy.serialize_keras_object(\n        self.model_config\n    )\n    return config\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    model_config = keras.utils.legacy.deserialize_keras_object(\n        config.get('model_config'), custom_objects=custom_objects\n    )\n    premade_lib.verify_config(model_config)\n    return cls(model_config,\n               name=config.get('name', None),\n               trainable=config.get('trainable', True))\n\n\nclass CalibratedLinear(keras.Model):\n  \"\"\"Premade model for Tensorflow calibrated linear models.\n\n  Creates a `keras.Model` for the model architecture specified by the\n  `model_config`, which should be a `tfl.configs.CalibratedLinearConfig`. No\n  fields in the model config will be automatically filled in, so the config\n  must be fully specified. Note that the inputs to the model should match the\n  order in which they are defined in the feature configs.\n\n  Example:\n\n  ```python\n  model_config = tfl.configs.CalibratedLinearConfig(...)\n  calibrated_linear_model = tfl.premade.CalibratedLinear(\n      model_config=model_config)\n  calibrated_linear_model.compile(...)\n  calibrated_linear_model.fit(...)\n  ```\n\n  Attributes:\n    model_config: Model configuration object describing model architecture.\n      Should be a `tfl.configs.CalibratedLinearConfig` instance.\n  \"\"\"\n\n  def __init__(self, model_config=None, dtype=tf.float32, **kwargs):\n    \"\"\"Initializes a `CalibratedLinear` instance.\n\n    Args:\n      model_config: Model configuration object describing model architecture.\n        Should be one of the model configs in `tfl.configs`.\n      dtype: dtype of layers used in the model.\n      **kwargs: Any additional `keras.Model` arguments.\n    \"\"\"\n    # Set our model_config\n    self.model_config = model_config\n    # Check if we are constructing with already provided inputs/outputs, e.g.\n    # when we are loading a model.\n    if 'inputs' in kwargs and 'outputs' in kwargs:\n      super(CalibratedLinear, self).__init__(**kwargs)\n      return\n    if model_config is None:\n      raise ValueError('Must provide a model_config.')\n    # Check that proper config has been given.\n    if not isinstance(model_config, configs.CalibratedLinearConfig):\n      raise ValueError('Invalid config type: {}'.format(type(model_config)))\n    # Verify that the config is fully specified.\n    premade_lib.verify_config(model_config)\n    # Get feature configs and construct model.\n    input_layer = premade_lib.build_input_layer(\n        feature_configs=model_config.feature_configs, dtype=dtype)\n\n    calibration_layer_output_range = (\n        premade_lib.LayerOutputRange.INPUT_TO_FINAL_CALIBRATION\n        if model_config.output_calibration else\n        premade_lib.LayerOutputRange.MODEL_OUTPUT)\n    submodels_inputs = premade_lib.build_calibration_layers(\n        calibration_input_layer=input_layer,\n        model_config=model_config,\n        layer_output_range=calibration_layer_output_range,\n        submodels=[[\n            feature_config.name\n            for feature_config in 
model_config.feature_configs\n        ]],\n        separate_calibrators=False,\n        dtype=dtype)\n\n    weighted_average = (\n        model_config.output_min is not None or\n        model_config.output_max is not None or model_config.output_calibration)\n    linear_output = premade_lib.build_linear_layer(\n        linear_input=submodels_inputs[0],\n        feature_configs=model_config.feature_configs,\n        model_config=model_config,\n        weighted_average=weighted_average,\n        submodel_index=0,\n        dtype=dtype)\n\n    if model_config.output_calibration:\n      model_output = premade_lib.build_output_calibration_layer(\n          output_calibration_input=linear_output,\n          model_config=model_config,\n          dtype=dtype)\n    else:\n      model_output = linear_output\n\n    # Define inputs and initialize model.\n    inputs = [\n        input_layer[feature_config.name]\n        for feature_config in model_config.feature_configs\n    ]\n    kwargs['inputs'] = inputs\n    kwargs['outputs'] = model_output\n    super(CalibratedLinear, self).__init__(**kwargs)\n\n  def get_config(self):\n    \"\"\"Returns a configuration dictionary.\"\"\"\n    config = {'name': self.name, 'trainable': self.trainable}\n    config['model_config'] = keras.utils.legacy.serialize_keras_object(\n        self.model_config\n    )\n    return config\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    model_config = keras.utils.legacy.deserialize_keras_object(\n        config.get('model_config'), custom_objects=custom_objects\n    )\n    premade_lib.verify_config(model_config)\n    return cls(model_config,\n               name=config.get('name', None),\n               trainable=config.get('trainable', True))\n\n\n# TODO: add support for tf.map_fn and inputs of shape (B, ?, input_dim)\n# as well as non-ragged inputs using padding/mask.\nclass AggregateFunction(keras.Model):\n  \"\"\"Premade model for Tensorflow aggregate function learning models.\n\n  Creates a `keras.Model` for the model architecture specified by the\n  `model_config`, which should be a\n  `tfl.configs.AggregateFunctionConfig`. No\n  fields in the model config will be automatically filled in, so the config\n  must be fully specified. Note that the inputs to the model should match the\n  order in which they are defined in the feature configs. 
Features will be\n  considered ragged, so inputs to this model must be `tf.RaggedTensor` instances.\n\n  Example:\n\n  ```python\n  model_config = tfl.configs.AggregateFunctionConfig(...)\n  agg_model = tfl.premade.AggregateFunction(\n      model_config=model_config)\n  agg_model.compile(...)\n  agg_model.fit(...)\n  ```\n  \"\"\"\n\n  def __init__(self, model_config=None, dtype=tf.float32, **kwargs):\n    \"\"\"Initializes an `AggregateFunction` instance.\n\n    Args:\n      model_config: Model configuration object describing model architecture.\n        Should be a `tfl.configs.AggregateFunctionConfig` instance.\n      dtype: dtype of layers used in the model.\n      **kwargs: Any additional `keras.Model` arguments.\n    \"\"\"\n    # Set our model_config\n    self.model_config = model_config\n    # Check if we are constructing with already provided inputs/outputs, e.g.\n    # when we are loading a model.\n    if 'inputs' in kwargs and 'outputs' in kwargs:\n      super(AggregateFunction, self).__init__(**kwargs)\n      return\n    if model_config is None:\n      raise ValueError('Must provide a model_config.')\n    # Check that proper config has been given.\n    if not isinstance(model_config, configs.AggregateFunctionConfig):\n      raise ValueError('Invalid config type: {}'.format(type(model_config)))\n    # Verify that the config is fully specified.\n    premade_lib.verify_config(model_config)\n    # Get feature configs and construct model.\n    input_layer = premade_lib.build_input_layer(\n        feature_configs=model_config.feature_configs, dtype=dtype, ragged=True)\n\n    # We need to construct middle_dimension calibrated_lattices for the\n    # aggregation layer. Note that we cannot do this in premade_lib because\n    # importing premade in premade_lib would cause a dependency cycle. 
Also\n    # note that we only need to set the output initialization to the min and\n    # max since we are not using output calibration at this step of the\n    # aggregation.\n    calibrated_lattice_config = configs.CalibratedLatticeConfig(\n        feature_configs=model_config.feature_configs,\n        interpolation=model_config.aggregation_lattice_interpolation,\n        regularizer_configs=model_config.regularizer_configs,\n        output_min=-1.0,\n        output_max=1.0,\n        output_initialization=[-1.0, 1.0])\n    calibrated_lattice_models = [\n        CalibratedLattice(calibrated_lattice_config)\n        for _ in range(model_config.middle_dimension)\n    ]\n    aggregation_layer_output_range = (\n        premade_lib.LayerOutputRange.INPUT_TO_FINAL_CALIBRATION\n        if model_config.output_calibration else\n        premade_lib.LayerOutputRange.MODEL_OUTPUT)\n    # Change input layer into a list based on model_config.feature_configs.\n    # This is the order of inputs expected by calibrated_lattice_models.\n    inputs = [\n        input_layer[feature_config.name]\n        for feature_config in model_config.feature_configs\n    ]\n    aggregation_output = premade_lib.build_aggregation_layer(\n        aggregation_input_layer=inputs,\n        model_config=model_config,\n        calibrated_lattice_models=calibrated_lattice_models,\n        layer_output_range=aggregation_layer_output_range,\n        submodel_index=0,\n        dtype=dtype)\n\n    if model_config.output_calibration:\n      model_output = premade_lib.build_output_calibration_layer(\n          output_calibration_input=aggregation_output,\n          model_config=model_config,\n          dtype=dtype)\n    else:\n      model_output = aggregation_output\n\n    # Define inputs and initialize model.\n    kwargs['inputs'] = inputs\n    kwargs['outputs'] = model_output\n    super(AggregateFunction, self).__init__(**kwargs)\n\n  def get_config(self):\n    \"\"\"Returns a configuration dictionary.\"\"\"\n    config = {'name': self.name, 'trainable': self.trainable}\n    config['model_config'] = keras.utils.legacy.serialize_keras_object(\n        self.model_config\n    )\n    return config\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    model_config = keras.utils.legacy.deserialize_keras_object(\n        config.get('model_config'), custom_objects=custom_objects\n    )\n    premade_lib.verify_config(model_config)\n    return cls(model_config,\n               name=config.get('name', None),\n               trainable=config.get('trainable', True))\n\n\ndef get_custom_objects(custom_objects=None):\n  \"\"\"Creates and returns a dictionary mapping names to custom objects.\n\n  Args:\n    custom_objects: Optional dictionary mapping names (strings) to custom\n      classes or functions to be considered during deserialization. 
If provided,\n      the returned mapping will be extended to contain this one.\n\n  Returns:\n    A dictionary mapping names (strings) to tensorflow lattice custom objects.\n  \"\"\"\n  tfl_custom_objects = {\n      'AggregateFunction':\n          AggregateFunction,\n      'AggregateFunctionConfig':\n          configs.AggregateFunctionConfig,\n      'Aggregation':\n          aggregation_layer.Aggregation,\n      'BiasInitializer':\n          kfll.BiasInitializer,\n      'CalibratedLatticeEnsemble':\n          CalibratedLatticeEnsemble,\n      'CalibratedLattice':\n          CalibratedLattice,\n      'CalibratedLatticeConfig':\n          configs.CalibratedLatticeConfig,\n      'CalibratedLatticeEnsembleConfig':\n          configs.CalibratedLatticeEnsembleConfig,\n      'CalibratedLinear':\n          CalibratedLinear,\n      'CalibratedLinearConfig':\n          configs.CalibratedLinearConfig,\n      'CategoricalCalibration':\n          categorical_calibration_layer.CategoricalCalibration,\n      'CategoricalCalibrationConstraints':\n          categorical_calibration_layer.CategoricalCalibrationConstraints,\n      'DominanceConfig':\n          configs.DominanceConfig,\n      'FeatureConfig':\n          configs.FeatureConfig,\n      'KFLRandomMonotonicInitializer':\n          kfll.KFLRandomMonotonicInitializer,\n      'KroneckerFactoredLattice':\n          kfll.KroneckerFactoredLattice,\n      'KroneckerFactoredLatticeConstraints':\n          kfll.KroneckerFactoredLatticeConstraints,\n      'LaplacianRegularizer':\n          lattice_layer.LaplacianRegularizer,\n      'Lattice':\n          lattice_layer.Lattice,\n      'LatticeConstraints':\n          lattice_layer.LatticeConstraints,\n      'Linear':\n          linear_layer.Linear,\n      'LinearConstraints':\n          linear_layer.LinearConstraints,\n      'LinearInitializer':\n          lattice_layer.LinearInitializer,\n      'NaiveBoundsConstraints':\n          pwl_calibration_layer.NaiveBoundsConstraints,\n      'ParallelCombination':\n          parallel_combination_layer.ParallelCombination,\n      'PWLCalibration':\n          pwl_calibration_layer.PWLCalibration,\n      'PWLCalibrationConstraints':\n          pwl_calibration_layer.PWLCalibrationConstraints,\n      'RandomMonotonicInitializer':\n          lattice_layer.RandomMonotonicInitializer,\n      'RegularizerConfig':\n          configs.RegularizerConfig,\n      'RTL':\n          rtl_layer.RTL,\n      'ScaleConstraints':\n          kfll.ScaleConstraints,\n      'ScaleInitializer':\n          kfll.ScaleInitializer,\n      'TorsionRegularizer':\n          lattice_layer.TorsionRegularizer,\n      'TrustConfig':\n          configs.TrustConfig,\n  }\n  if custom_objects is not None:\n    tfl_custom_objects.update(custom_objects)\n  return tfl_custom_objects\n"
  },
  {
    "path": "tensorflow_lattice/python/premade_lib.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Implementation of algorithms required for premade models.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport copy\nimport enum\nimport itertools\n\nfrom absl import logging\nimport numpy as np\nimport six\n\nimport tensorflow as tf\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, 'version', None)\nif version_fn and version_fn().startswith('3.'):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\nfrom . import aggregation_layer\nfrom . import categorical_calibration_layer\nfrom . import configs\nfrom . import kronecker_factored_lattice_layer as kfll\nfrom . import kronecker_factored_lattice_lib as kfl_lib\nfrom . import lattice_layer\nfrom . import lattice_lib\nfrom . import linear_layer\nfrom . import pwl_calibration_layer\nfrom . import rtl_layer\nfrom . import utils\n\n\n# Layer names used for layers in the premade models.\nAGGREGATION_LAYER_NAME = 'tfl_aggregation'\nCALIB_LAYER_NAME = 'tfl_calib'\nINPUT_LAYER_NAME = 'tfl_input'\nKFL_LAYER_NAME = 'tfl_kronecker_factored_lattice'\nLATTICE_LAYER_NAME = 'tfl_lattice'\nLINEAR_LAYER_NAME = 'tfl_linear'\nOUTPUT_LINEAR_COMBINATION_LAYER_NAME = 'tfl_output_linear_combination'\nOUTPUT_CALIB_LAYER_NAME = 'tfl_output_calib'\nRTL_LAYER_NAME = 'tfl_rtl'\nRTL_INPUT_NAME = 'tfl_rtl_input'\n\n# Prefix for passthrough (identity) nodes for shared calibration.\n# These nodes pass shared calibrated values to submodels in an ensemble.\nCALIB_PASSTHROUGH_NAME = 'tfl_calib_passthrough'\n\n# Prefix for defining feature calibrator regularizers.\n_INPUT_CALIB_REGULARIZER_PREFIX = 'calib_'\n\n# Prefix for defining output calibrator regularizers.\n_OUTPUT_CALIB_REGULARIZER_PREFIX = 'output_calib_'\n\n# Weight of laplacian in feature importance for the crystal algorithm.\n_LAPLACIAN_WEIGHT_IN_IMPORTANCE = 6.0\n\n# Discount amount for repeated co-occurrence of pairs of features in crystals.\n_REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE = 0.5\n\n# Maximum number of swaps for the crystals algorithm.\n_MAX_CRYSTALS_SWAPS = 1000\n\n\ndef _input_calibration_regularizers(model_config, feature_config):\n  \"\"\"Returns pwl layer regularizers defined in the model and feature configs.\"\"\"\n  regularizer_configs = []\n  regularizer_configs.extend(feature_config.regularizer_configs or [])\n  regularizer_configs.extend(model_config.regularizer_configs or [])\n  return [(r.name.replace(_INPUT_CALIB_REGULARIZER_PREFIX, ''), r.l1, r.l2)\n          for r in regularizer_configs\n          if r.name.startswith(_INPUT_CALIB_REGULARIZER_PREFIX)]\n\n\ndef _middle_calibration_regularizers(model_config):\n  \"\"\"Returns pwl layer regularizers defined in the model config.\"\"\"\n  regularizer_configs = []\n  regularizer_configs.extend(model_config.regularizer_configs or [])\n  return [(r.name.replace(_INPUT_CALIB_REGULARIZER_PREFIX, ''), r.l1, 
r.l2)\n          for r in regularizer_configs\n          if r.name.startswith(_INPUT_CALIB_REGULARIZER_PREFIX)]\n\n\ndef _output_calibration_regularizers(model_config):\n  \"\"\"Returns output calibration regularizers defined in the model config.\"\"\"\n  return [(r.name.replace(_OUTPUT_CALIB_REGULARIZER_PREFIX, ''), r.l1, r.l2)\n          for r in model_config.regularizer_configs or []\n          if r.name.startswith(_OUTPUT_CALIB_REGULARIZER_PREFIX)]\n\n\ndef _lattice_regularizers(model_config, feature_configs):\n  \"\"\"Returns lattice regularizers defined in the model and feature configs.\"\"\"\n  # dict from regularizer name to pair of per feature l1 and l2 amounts.\n  regularizers_dict = {}\n  n_dims = len(feature_configs)\n  for index, feature_config in enumerate(feature_configs):\n    for regularizer_config in feature_config.regularizer_configs or []:\n      if not (\n          regularizer_config.name.startswith(_INPUT_CALIB_REGULARIZER_PREFIX) or\n          regularizer_config.name.startswith(_OUTPUT_CALIB_REGULARIZER_PREFIX)):\n        if regularizer_config.name not in regularizers_dict:\n          regularizers_dict[regularizer_config.name] = ([0.0] * n_dims,\n                                                        [0.0] * n_dims)\n        regularizers_dict[\n            regularizer_config.name][0][index] += regularizer_config.l1\n        regularizers_dict[\n            regularizer_config.name][1][index] += regularizer_config.l2\n\n  regularizers = [(k,) + v for k, v in regularizers_dict.items()]\n\n  for regularizer_config in model_config.regularizer_configs or []:\n    if not (\n        regularizer_config.name.startswith(_INPUT_CALIB_REGULARIZER_PREFIX) or\n        regularizer_config.name.startswith(_OUTPUT_CALIB_REGULARIZER_PREFIX)):\n      regularizers.append((regularizer_config.name, regularizer_config.l1,\n                           regularizer_config.l2))\n  return regularizers\n\n\nclass LayerOutputRange(enum.Enum):\n  \"\"\"Enum to indicate the output range based on the input of the next layers.\"\"\"\n  MODEL_OUTPUT = 1\n  INPUT_TO_LATTICE = 2\n  INPUT_TO_FINAL_CALIBRATION = 3\n\n\ndef _output_range(layer_output_range, model_config, feature_config=None):\n  \"\"\"Returns min/max/init_min/init_max for a given output range.\"\"\"\n  if layer_output_range == LayerOutputRange.INPUT_TO_LATTICE:\n    if feature_config is None:\n      raise ValueError('Expecting feature config for lattice inputs.')\n    output_init_min = output_min = 0.0\n    output_init_max = output_max = feature_config.lattice_size - 1.0\n  elif layer_output_range == LayerOutputRange.MODEL_OUTPUT:\n    output_min = model_config.output_min\n    output_max = model_config.output_max\n    # Note: due to the multiplicative nature of KroneckerFactoredLattice layers,\n    # the initialization min/max do not correspond directly to the output\n    # min/max. 
Thus we follow the same scheme as the KroneckerFactoredLattice\n    # lattice layer to properly initialize the kernel and scale such that\n    # the output does in fact respect the requested bounds.\n    if ((isinstance(model_config, configs.CalibratedLatticeEnsembleConfig) or\n         isinstance(model_config, configs.CalibratedLatticeConfig)) and\n        model_config.parameterization == 'kronecker_factored'):\n      output_init_min, output_init_max = kfl_lib.default_init_params(\n          output_min, output_max)\n    else:\n      output_init_min = np.min(model_config.output_initialization)\n      output_init_max = np.max(model_config.output_initialization)\n  elif layer_output_range == LayerOutputRange.INPUT_TO_FINAL_CALIBRATION:\n    output_init_min = output_min = 0.0\n    output_init_max = output_max = 1.0\n  else:\n    raise ValueError('Unsupported layer output range.')\n  return output_min, output_max, output_init_min, output_init_max\n\n\ndef build_input_layer(feature_configs, dtype, ragged=False):\n  \"\"\"Creates a mapping from feature name to `keras.Input`.\n\n  Args:\n    feature_configs: A list of `tfl.configs.FeatureConfig` instances that\n      specify configurations for each feature.\n    dtype: dtype\n    ragged: If the inputs are ragged tensors.\n\n  Returns:\n    Mapping from feature name to `keras.Input` for the inputs specified by\n      `feature_configs`.\n  \"\"\"\n  input_layer = {}\n  shape = (None,) if ragged else (1,)\n  for feature_config in feature_configs:\n    layer_name = '{}_{}'.format(INPUT_LAYER_NAME, feature_config.name)\n    if feature_config.num_buckets:\n      input_layer[feature_config.name] = keras.Input(\n          shape=shape, ragged=ragged, dtype=tf.int32, name=layer_name)\n    else:\n      input_layer[feature_config.name] = keras.Input(\n          shape=shape, ragged=ragged, dtype=dtype, name=layer_name)\n  return input_layer\n\n\ndef build_multi_unit_calibration_layers(calibration_input_layer,\n                                        calibration_output_units, model_config,\n                                        layer_output_range,\n                                        output_single_tensor, dtype):\n  \"\"\"Creates a mapping from feature names to calibration outputs.\n\n  Args:\n    calibration_input_layer: A mapping from feature name to `keras.Input`.\n    calibration_output_units: A mapping from feature name to units.\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n    layer_output_range: A `tfl.premade_lib.LayerOutputRange` enum.\n    output_single_tensor: If output for each feature should be a single tensor.\n    dtype: dtype\n\n  Returns:\n    A mapping from feature name to calibration output Tensors.\n  \"\"\"\n  calibration_output = {}\n  for feature_name, units in calibration_output_units.items():\n    if units == 0:\n      raise ValueError(\n          'Feature {} is not used. 
Calibration output units is 0.'.format(\n              feature_name))\n    feature_config = model_config.feature_config_by_name(feature_name)\n    calibration_input = calibration_input_layer[feature_name]\n    layer_name = '{}_{}'.format(CALIB_LAYER_NAME, feature_name)\n\n    (output_min, output_max, output_init_min,\n     output_init_max) = _output_range(layer_output_range, model_config,\n                                      feature_config)\n\n    if feature_config.num_buckets:\n      kernel_initializer = keras.initializers.RandomUniform(\n          output_init_min, output_init_max)\n      calibrated = (\n          categorical_calibration_layer.CategoricalCalibration(\n              num_buckets=feature_config.num_buckets,\n              units=units,\n              output_min=output_min,\n              output_max=output_max,\n              kernel_initializer=kernel_initializer,\n              monotonicities=feature_config.monotonicity if isinstance(\n                  feature_config.monotonicity, list) else None,\n              default_input_value=feature_config.default_value,\n              split_outputs=(units > 1 and not output_single_tensor),\n              dtype=dtype,\n              name=layer_name)(calibration_input))\n    else:\n      kernel_regularizer = _input_calibration_regularizers(\n          model_config, feature_config)\n      monotonicity = feature_config.monotonicity\n      if (utils.canonicalize_monotonicity(monotonicity) == 0 and\n          feature_config.pwl_calibration_always_monotonic):\n        monotonicity = 1\n      kernel_initializer = pwl_calibration_layer.UniformOutputInitializer(\n          output_min=output_init_min,\n          output_max=output_init_max,\n          monotonicity=monotonicity,\n          keypoints=feature_config.pwl_calibration_input_keypoints)\n      calibrated = (\n          pwl_calibration_layer.PWLCalibration(\n              units=units,\n              input_keypoints=feature_config.pwl_calibration_input_keypoints,\n              output_min=output_min,\n              output_max=output_max,\n              clamp_min=feature_config.pwl_calibration_clamp_min,\n              clamp_max=feature_config.pwl_calibration_clamp_max,\n              missing_input_value=feature_config.default_value,\n              impute_missing=(feature_config.default_value is not None),\n              kernel_initializer=kernel_initializer,\n              kernel_regularizer=kernel_regularizer,\n              monotonicity=monotonicity,\n              convexity=feature_config.pwl_calibration_convexity,\n              split_outputs=(units > 1 and not output_single_tensor),\n              input_keypoints_type=feature_config\n              .pwl_calibration_input_keypoints_type,\n              dtype=dtype,\n              name=layer_name)(calibration_input))\n    if output_single_tensor:\n      calibration_output[feature_name] = calibrated\n    elif units == 1:\n      calibration_output[feature_name] = [calibrated]\n    else:\n      # calibrated will have already been split in this case.\n      calibration_output[feature_name] = calibrated\n  return calibration_output\n\n\ndef build_calibration_layers(calibration_input_layer, model_config,\n                             layer_output_range, submodels,\n                             separate_calibrators, dtype):\n  \"\"\"Creates a calibration layer for `submodels` as list of list of features.\n\n  Args:\n    calibration_input_layer: A mapping from feature name to `keras.Input`.\n    model_config: Model configuration object 
describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n    layer_output_range: A `tfl.premade_lib.LayerOutputRange` enum.\n    submodels: A list of list of feature names.\n    separate_calibrators: If features should be separately calibrated for each\n      lattice in an ensemble.\n    dtype: dtype\n\n  Returns:\n    A list of list of Tensors representing a calibration layer for `submodels`.\n  \"\"\"\n  # Create a list of (feature_name, calibration_output_idx) pairs for each\n  # submodel. When using shared calibration, all submodels will have\n  # calibration_output_idx = 0.\n  submodels_input_features = []\n  calibration_last_index = collections.defaultdict(int)\n  for submodel in submodels:\n    submodel_input_features = []\n    submodels_input_features.append(submodel_input_features)\n    for feature_name in submodel:\n      submodel_input_features.append(\n          (feature_name, calibration_last_index[feature_name]))\n      if separate_calibrators:\n        calibration_last_index[feature_name] += 1\n\n  # This is to account for shared calibration.\n  calibration_output_units = {\n      name: max(index, 1) for name, index in calibration_last_index.items()\n  }\n  calibration_output = build_multi_unit_calibration_layers(\n      calibration_input_layer=calibration_input_layer,\n      calibration_output_units=calibration_output_units,\n      model_config=model_config,\n      layer_output_range=layer_output_range,\n      output_single_tensor=False,\n      dtype=dtype)\n\n  # Create passthrough nodes for each submodel input so that we can recover\n  # the model structure for plotting and analysis.\n  # {CALIB_PASSTHROUGH_NAME}_{feature_name}_\n  #   {calibration_output_idx}_{submodel_idx}_{submodel_input_idx}\n  submodels_inputs = []\n  for submodel_idx, submodel_input_features in enumerate(\n      submodels_input_features):\n    submodel_inputs = []\n    submodels_inputs.append(submodel_inputs)\n    for (submodel_input_idx,\n         (feature_name,\n          calibration_output_idx)) in enumerate(submodel_input_features):\n      passthrough_name = '{}_{}_{}_{}_{}'.format(CALIB_PASSTHROUGH_NAME,\n                                                 feature_name,\n                                                 calibration_output_idx,\n                                                 submodel_idx,\n                                                 submodel_input_idx)\n      submodel_inputs.append(\n          tf.identity(\n              calibration_output[feature_name][calibration_output_idx],\n              name=passthrough_name))\n\n  return submodels_inputs\n\n\ndef build_aggregation_layer(aggregation_input_layer, model_config,\n                            calibrated_lattice_models, layer_output_range,\n                            submodel_index, dtype):\n  \"\"\"Creates an aggregation layer using the given calibrated lattice models.\n\n  Args:\n    aggregation_input_layer: A list or a mapping from feature name to\n      `keras.Input`, in the order or format expected by\n      `calibrated_lattice_models`.\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n    calibrated_lattice_models: A list of calibrated lattice models of size\n      model_config.middle_dimension, where each calibrated lattice model\n      instance is constructed using the same model configuration object.\n    layer_output_range: A `tfl.premade_lib.LayerOutputRange` enum.\n    submodel_index: 
Corresponding index into submodels.\n    dtype: dtype\n\n  Returns:\n    A Tensor representing the output of the middle lattice applied to the\n    aggregated embeddings.\n  \"\"\"\n  (output_min, output_max, output_init_min,\n   output_init_max) = _output_range(layer_output_range, model_config)\n\n  lattice_sizes = [model_config.middle_lattice_size\n                  ] * model_config.middle_dimension\n  lattice_monotonicities = [1] * model_config.middle_dimension\n\n  # Create the aggregated embeddings to pass to the middle lattice.\n  lattice_inputs = []\n  for i in range(model_config.middle_dimension):\n    agg_layer_name = '{}_{}'.format(AGGREGATION_LAYER_NAME, i)\n    agg_output = aggregation_layer.Aggregation(\n        calibrated_lattice_models[i], name=agg_layer_name)(\n            aggregation_input_layer)\n    agg_output = keras.layers.Reshape((1,))(agg_output)\n    if model_config.middle_calibration:\n      agg_output = pwl_calibration_layer.PWLCalibration(\n          input_keypoints=np.linspace(\n              -1.0,\n              1.0,\n              num=model_config.middle_calibration_num_keypoints,\n              dtype=np.float32),\n          output_min=0.0,\n          output_max=lattice_sizes[i] - 1.0,\n          monotonicity=utils.canonicalize_monotonicity(\n              model_config.middle_monotonicity),\n          kernel_regularizer=_middle_calibration_regularizers(model_config),\n          input_keypoints_type=model_config\n          .middle_calibration_input_keypoints_type,\n          dtype=dtype,\n      )(\n          agg_output)\n      agg_output = keras.layers.Reshape((1,))(agg_output)\n    lattice_inputs.append(agg_output)\n\n  # We use random monotonic initialization here to break the symmetry that we\n  # would otherwise have between middle lattices. 
Since we use the same\n  # CalibratedLattice for each of the middle dimensions, if we do not randomly\n  # initialize the middle lattice we will have the same gradient flow back for\n  # each middle dimension, thus acting the same as if there was only one middle\n  # dimension.\n  kernel_initializer = lattice_layer.RandomMonotonicInitializer(\n      lattice_sizes=lattice_sizes,\n      output_min=output_init_min,\n      output_max=output_init_max)\n  lattice_layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, submodel_index)\n  return lattice_layer.Lattice(\n      lattice_sizes=lattice_sizes,\n      monotonicities=lattice_monotonicities,\n      output_min=output_min,\n      output_max=output_max,\n      clip_inputs=False,\n      interpolation=model_config.middle_lattice_interpolation,\n      kernel_initializer=kernel_initializer,\n      dtype=dtype,\n      name=lattice_layer_name,\n  )(\n      lattice_inputs)\n\n\ndef _monotonicities_from_feature_configs(feature_configs):\n  \"\"\"Returns list of monotonicities defined in the given feature_configs.\"\"\"\n  monotonicities = []\n  for feature_config in feature_configs:\n    if not feature_config.monotonicity:\n      monotonicities.append(0)\n    elif (isinstance(feature_config.monotonicity, six.string_types) and\n          feature_config.monotonicity.lower() == 'none'):\n      monotonicities.append(0)\n    else:\n      monotonicities.append(1)\n  return monotonicities\n\n\ndef _dominance_constraints_from_feature_configs(feature_configs):\n  \"\"\"Returns list of dominance constraints in the given feature_configs.\"\"\"\n  feature_names = [feature_config.name for feature_config in feature_configs]\n  monotonic_dominances = []\n  for dominant_idx, dominant_feature_config in enumerate(feature_configs):\n    for dominance_config in dominant_feature_config.dominates or []:\n      if dominance_config.feature_name in feature_names:\n        weak_idx = feature_names.index(dominance_config.feature_name)\n        if dominance_config.dominance_type == 'monotonic':\n          monotonic_dominances.append((dominant_idx, weak_idx))\n        else:\n          raise ValueError('Unrecognized dominance type: {}'.format(\n              dominance_config.dominance_type))\n  return monotonic_dominances\n\n\ndef _canonical_feature_names(model_config, feature_names=None):\n  if feature_names is not None:\n    return feature_names\n  if model_config.feature_configs is None:\n    raise ValueError(\n        'Feature configs must be specified if feature names are not provided.')\n  return [\n      feature_config.name for feature_config in model_config.feature_configs\n  ]\n\n\ndef build_linear_layer(linear_input, feature_configs, model_config,\n                       weighted_average, submodel_index, dtype):\n  \"\"\"Creates a `tfl.layers.Linear` layer initialized to be an average.\n\n  Args:\n    linear_input: Input to the linear layer.\n    feature_configs: A list of `tfl.configs.FeatureConfig` instances that\n      specify configurations for each feature.\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n    weighted_average: If the linear coefficients should be positive and sum up\n      to one.\n    submodel_index: Corresponding index into submodels.\n    dtype: dtype\n\n  Returns:\n    A `tfl.layers.Linear` instance.\n  \"\"\"\n  layer_name = '{}_{}'.format(LINEAR_LAYER_NAME, submodel_index)\n\n  linear_input = keras.layers.Concatenate(axis=1)(linear_input)\n  num_input_dims = 
len(feature_configs)\n  kernel_initializer = keras.initializers.Constant([1.0 / num_input_dims] *\n                                                      num_input_dims)\n  bias_initializer = keras.initializers.Constant(0)\n\n  if weighted_average:\n    # Linear coefficients should be positive and sum up to one.\n    linear_monotonicities = [1] * num_input_dims\n    normalization_order = 1\n    use_bias = False\n  else:\n    linear_monotonicities = _monotonicities_from_feature_configs(\n        feature_configs)\n    normalization_order = None\n    use_bias = model_config.use_bias\n\n  monotonic_dominances = _dominance_constraints_from_feature_configs(\n      feature_configs)\n\n  return linear_layer.Linear(\n      num_input_dims=num_input_dims,\n      monotonicities=linear_monotonicities,\n      monotonic_dominances=monotonic_dominances,\n      use_bias=use_bias,\n      normalization_order=normalization_order,\n      kernel_initializer=kernel_initializer,\n      bias_initializer=bias_initializer,\n      dtype=dtype,\n      name=layer_name)(\n          linear_input)\n\n\ndef build_lattice_layer(lattice_input, feature_configs, model_config,\n                        layer_output_range, submodel_index, is_inside_ensemble,\n                        dtype):\n  \"\"\"Creates a `tfl.layers.Lattice` layer.\n\n  Args:\n    lattice_input: Input to the lattice layer.\n    feature_configs: A list of `tfl.configs.FeatureConfig` instances that\n      specify configurations for each feature.\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n    layer_output_range: A `tfl.premade_lib.LayerOutputRange` enum.\n    submodel_index: Corresponding index into submodels.\n    is_inside_ensemble: If this layer is inside an ensemble.\n    dtype: dtype\n\n  Returns:\n    A `tfl.layers.Lattice` instance if `model_config.parameterization` is set to\n    `'all_vertices'` or a `tfl.layers.KroneckerFactoredLattice` instance if\n    set to `'kronecker_factored'`.\n\n  Raises:\n    ValueError: If `model_config.parameterization` is not one of\n      `'all_vertices'` or `'kronecker_factored'`.\n  \"\"\"\n  layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, submodel_index)\n\n  (output_min, output_max, output_init_min,\n   output_init_max) = _output_range(layer_output_range, model_config)\n\n  feature_names = [feature_config.name for feature_config in feature_configs]\n  lattice_sizes = [\n      feature_config.lattice_size for feature_config in feature_configs\n  ]\n  lattice_monotonicities = _monotonicities_from_feature_configs(feature_configs)\n  lattice_unimodalities = [\n      feature_config.unimodality for feature_config in feature_configs\n  ]\n  lattice_regularizers = _lattice_regularizers(model_config,\n                                               feature_configs) or None\n\n  # Construct trust constraints within this lattice.\n  edgeworth_trusts = []\n  trapezoid_trusts = []\n  for conditional_idx, conditional_feature_config in enumerate(feature_configs):\n    for trust_config in conditional_feature_config.reflects_trust_in or []:\n      if trust_config.feature_name in feature_names:\n        main_idx = feature_names.index(trust_config.feature_name)\n        if trust_config.trust_type == 'edgeworth':\n          edgeworth_trusts.append(\n              (main_idx, conditional_idx, trust_config.direction))\n        elif trust_config.trust_type == 'trapezoid':\n          trapezoid_trusts.append(\n              (main_idx, conditional_idx, 
trust_config.direction))\n        else:\n          raise ValueError('Unrecognized trust type: {}'.format(\n              trust_config.trust_type))\n      elif is_inside_ensemble and trust_config.trust_type == 'trapezoid':\n        logging.warning(\n            'A \"main\" feature (%s) for a trapezoid trust constraint is not '\n            'present in a lattice that includes the \"conditional\" feature '\n            '(%s). In an ensemble model, this can result in constraint '\n            'violations. Consider manually setting the ensemble structure if '\n            'this constraint needs to be satisfied.', trust_config.feature_name,\n            conditional_feature_config.name)\n\n  monotonic_dominances = _dominance_constraints_from_feature_configs(\n      feature_configs)\n\n  if model_config.parameterization == 'all_vertices':\n    layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, submodel_index)\n    kernel_initializer = lattice_layer.LinearInitializer(\n        lattice_sizes=lattice_sizes,\n        monotonicities=lattice_monotonicities,\n        unimodalities=lattice_unimodalities,\n        output_min=output_init_min,\n        output_max=output_init_max)\n    return lattice_layer.Lattice(\n        lattice_sizes=lattice_sizes,\n        monotonicities=lattice_monotonicities,\n        unimodalities=lattice_unimodalities,\n        edgeworth_trusts=edgeworth_trusts,\n        trapezoid_trusts=trapezoid_trusts,\n        monotonic_dominances=monotonic_dominances,\n        output_min=output_min,\n        output_max=output_max,\n        clip_inputs=False,\n        interpolation=model_config.interpolation,\n        kernel_regularizer=lattice_regularizers,\n        kernel_initializer=kernel_initializer,\n        dtype=dtype,\n        name=layer_name)(\n            lattice_input)\n  elif model_config.parameterization == 'kronecker_factored':\n    layer_name = '{}_{}'.format(KFL_LAYER_NAME, submodel_index)\n    kernel_initializer = kfll.KFLRandomMonotonicInitializer(\n        monotonicities=lattice_monotonicities,\n        init_min=output_init_min,\n        init_max=output_init_max,\n        seed=model_config.random_seed)\n    scale_initializer = kfll.ScaleInitializer(\n        output_min=output_min, output_max=output_max)\n    return kfll.KroneckerFactoredLattice(\n        lattice_sizes=lattice_sizes[0],\n        num_terms=model_config.num_terms,\n        monotonicities=lattice_monotonicities,\n        output_min=output_min,\n        output_max=output_max,\n        clip_inputs=False,\n        kernel_initializer=kernel_initializer,\n        scale_initializer=scale_initializer,\n        dtype=dtype,\n        name=layer_name)(\n            lattice_input)\n  else:\n    raise ValueError('Unknown type of parameterization: {}'.format(\n        model_config.parameterization))\n\n\ndef build_lattice_ensemble_layer(submodels_inputs, model_config, dtype):\n  \"\"\"Creates an ensemble of `tfl.layers.Lattice` layers.\n\n  Args:\n    submodels_inputs: List of inputs to each of the lattice layers in the\n      ensemble. 
The order corresponds to the elements of model_config.lattices.\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n    dtype: dtype\n\n  Returns:\n    A list of `tfl.layers.Lattice` instances.\n  \"\"\"\n  lattice_outputs = []\n  for submodel_index, (lattice_feature_names, lattice_input) in enumerate(\n      zip(model_config.lattices, submodels_inputs)):\n    lattice_feature_configs = [\n        model_config.feature_config_by_name(feature_name)\n        for feature_name in lattice_feature_names\n    ]\n    lattice_layer_output_range = (\n        LayerOutputRange.INPUT_TO_FINAL_CALIBRATION\n        if model_config.output_calibration else LayerOutputRange.MODEL_OUTPUT)\n    lattice_outputs.append(\n        build_lattice_layer(\n            lattice_input=lattice_input,\n            feature_configs=lattice_feature_configs,\n            model_config=model_config,\n            layer_output_range=lattice_layer_output_range,\n            submodel_index=submodel_index,\n            is_inside_ensemble=True,\n            dtype=dtype))\n  return lattice_outputs\n\n\ndef build_rtl_layer(calibration_outputs, model_config, submodel_index,\n                    average_outputs, dtype):\n  \"\"\"Creates a `tfl.layers.RTL` layer.\n\n  This function expects that all features defined in\n  model_config.feature_configs are used and present in calibration_outputs.\n\n  Args:\n    calibration_outputs: A mapping from feature name to calibration output.\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n    submodel_index: Corresponding index into submodels.\n    average_outputs: Whether to average the outputs of this layer.\n    dtype: dtype\n\n  Returns:\n    A `tfl.layers.RTL` instance.\n\n  Raises:\n    ValueError: If `model_config.parameterization` is not one of\n      `'all_vertices'` or `'kronecker_factored'`.\n  \"\"\"\n  layer_name = '{}_{}'.format(RTL_LAYER_NAME, submodel_index)\n\n  rtl_layer_output_range = (\n      LayerOutputRange.INPUT_TO_FINAL_CALIBRATION\n      if model_config.output_calibration else LayerOutputRange.MODEL_OUTPUT)\n\n  (output_min, output_max, output_init_min,\n   output_init_max) = _output_range(rtl_layer_output_range, model_config)\n\n  lattice_regularizers = _lattice_regularizers(\n      model_config, model_config.feature_configs) or None\n\n  rtl_inputs = collections.defaultdict(list)\n  for feature_config in model_config.feature_configs:\n    passthrough_name = '{}_{}'.format(RTL_INPUT_NAME, feature_config.name)\n    calibration_output = tf.identity(\n        calibration_outputs[feature_config.name], name=passthrough_name)\n    if feature_config.monotonicity in [1, -1, 'increasing', 'decreasing']:\n      rtl_inputs['increasing'].append(calibration_output)\n    else:\n      rtl_inputs['unconstrained'].append(calibration_output)\n\n  lattice_size = model_config.feature_configs[0].lattice_size\n  if model_config.parameterization == 'all_vertices':\n    kernel_initializer = 'random_monotonic_initializer'\n  elif model_config.parameterization == 'kronecker_factored':\n    kernel_initializer = 'kfl_random_monotonic_initializer'\n  else:\n    raise ValueError('Unknown type of parameterization: {}'.format(\n        model_config.parameterization))\n  return rtl_layer.RTL(\n      num_lattices=model_config.num_lattices,\n      lattice_rank=model_config.lattice_rank,\n      lattice_size=lattice_size,\n      
output_min=output_min,\n      output_max=output_max,\n      init_min=output_init_min,\n      init_max=output_init_max,\n      random_seed=model_config.random_seed,\n      clip_inputs=False,\n      interpolation=model_config.interpolation,\n      parameterization=model_config.parameterization,\n      num_terms=model_config.num_terms,\n      kernel_regularizer=lattice_regularizers,\n      kernel_initializer=kernel_initializer,\n      average_outputs=average_outputs,\n      dtype=dtype,\n      name=layer_name)(\n          rtl_inputs)\n\n\ndef build_calibrated_lattice_ensemble_layer(calibration_input_layer,\n                                            model_config, average_outputs,\n                                            dtype):\n  \"\"\"Creates a calibration layer followed by a lattice ensemble layer.\n\n  Args:\n    calibration_input_layer: A mapping from feature name to `keras.Input`.\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n    average_outputs: Whether to average the outputs of this layer.\n    dtype: dtype\n\n  Returns:\n    A `tfl.layers.RTL` instance if model_config.lattices is 'rtl_layer'.\n    Otherwise a list of `tfl.layers.Lattice` instances.\n  \"\"\"\n  if model_config.lattices == 'rtl_layer':\n    num_features = len(model_config.feature_configs)\n    units = [1] * num_features\n    if model_config.separate_calibrators:\n      num_inputs = model_config.num_lattices * model_config.lattice_rank\n      # We divide the number of inputs semi-evenly by the number of features.\n      # TODO: support setting number of calibration units.\n      for i in range(num_features):\n        units[i] = ((i + 1) * num_inputs // num_features -\n                    i * num_inputs // num_features)\n    calibration_output_units = {\n        feature_config.name: units[i]\n        for i, feature_config in enumerate(model_config.feature_configs)\n    }\n    calibration_outputs = build_multi_unit_calibration_layers(\n        calibration_input_layer=calibration_input_layer,\n        calibration_output_units=calibration_output_units,\n        model_config=model_config,\n        layer_output_range=LayerOutputRange.INPUT_TO_LATTICE,\n        output_single_tensor=True,\n        dtype=dtype)\n\n    lattice_outputs = build_rtl_layer(\n        calibration_outputs=calibration_outputs,\n        model_config=model_config,\n        submodel_index=0,\n        average_outputs=average_outputs,\n        dtype=dtype)\n  else:\n    submodels_inputs = build_calibration_layers(\n        calibration_input_layer=calibration_input_layer,\n        model_config=model_config,\n        layer_output_range=LayerOutputRange.INPUT_TO_LATTICE,\n        submodels=model_config.lattices,\n        separate_calibrators=model_config.separate_calibrators,\n        dtype=dtype)\n\n    lattice_outputs = build_lattice_ensemble_layer(\n        submodels_inputs=submodels_inputs,\n        model_config=model_config,\n        dtype=dtype)\n\n    if average_outputs:\n      lattice_outputs = keras.layers.Average()(lattice_outputs)\n\n  return lattice_outputs\n\n\ndef build_linear_combination_layer(ensemble_outputs, model_config, dtype):\n  \"\"\"Creates a `tfl.layers.Linear` layer initialized to be an average.\n\n  Args:\n    ensemble_outputs: Ensemble outputs to be linearly combined.\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n    dtype: dtype\n\n  Returns:\n    
A `tfl.layers.Linear` instance.\n  \"\"\"\n  if isinstance(ensemble_outputs, list):\n    num_input_dims = len(ensemble_outputs)\n    linear_input = keras.layers.Concatenate(axis=1)(ensemble_outputs)\n  else:\n    num_input_dims = int(ensemble_outputs.shape[1])\n    linear_input = ensemble_outputs\n  kernel_initializer = keras.initializers.Constant(1.0 / num_input_dims)\n  bias_initializer = keras.initializers.Constant(0)\n\n  if (not model_config.output_calibration and\n      model_config.output_min is None and model_config.output_max is None):\n    normalization_order = None\n  else:\n    # We need to use weighted average to keep the output range.\n    normalization_order = 1\n    # Bias term cannot be used when this layer should have bounded output.\n    if model_config.use_bias:\n      raise ValueError('Cannot use a bias term in linear combination with '\n                       'output bounds or output calibration')\n\n  return linear_layer.Linear(\n      num_input_dims=num_input_dims,\n      monotonicities=['increasing'] * num_input_dims,\n      normalization_order=normalization_order,\n      use_bias=model_config.use_bias,\n      kernel_initializer=kernel_initializer,\n      bias_initializer=bias_initializer,\n      dtype=dtype,\n      name=OUTPUT_LINEAR_COMBINATION_LAYER_NAME)(\n          linear_input)\n\n\ndef build_output_calibration_layer(output_calibration_input, model_config,\n                                   dtype):\n  \"\"\"Creates a monotonic output calibration layer with inputs range [0, 1].\n\n  Args:\n    output_calibration_input: Input to the output calibration layer.\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n    dtype: dtype\n\n  Returns:\n    A `tfl.layers.PWLCalibration` instance.\n  \"\"\"\n  # kernel format: bias followed by diffs between consecutive keypoint outputs.\n  kernel_init_values = np.ediff1d(\n      model_config.output_initialization,\n      to_begin=model_config.output_initialization[0])\n  input_keypoints = np.linspace(0.0, 1.0, num=len(kernel_init_values))\n  kernel_initializer = keras.initializers.Constant(kernel_init_values)\n  kernel_regularizer = _output_calibration_regularizers(model_config)\n  return pwl_calibration_layer.PWLCalibration(\n      input_keypoints=input_keypoints,\n      output_min=model_config.output_min,\n      output_max=model_config.output_max,\n      kernel_initializer=kernel_initializer,\n      kernel_regularizer=kernel_regularizer,\n      monotonicity=1,\n      input_keypoints_type=model_config.output_calibration_input_keypoints_type,\n      dtype=dtype,\n      name=OUTPUT_CALIB_LAYER_NAME)(\n          output_calibration_input)\n\n\ndef set_categorical_monotonicities(feature_configs):\n  \"\"\"Maps categorical monotonicities to indices based on specified vocab list.\n\n  Args:\n    feature_configs: A list of `tfl.configs.FeatureConfig` objects.\n  \"\"\"\n  if not isinstance(feature_configs, list) or any(\n      not isinstance(fc, configs.FeatureConfig) for fc in feature_configs):\n    raise ValueError(\n        'feature_configs must be a list of tfl.configs.FeatureConfig objects: '\n        '{}'.format(feature_configs))\n  for feature_config in feature_configs:\n    if feature_config.num_buckets and isinstance(feature_config.monotonicity,\n                                                 list):\n      # Make sure the vocabulary list exists. 
If not, assume user has already\n      # properly set monotonicity as proper indices for this calibrator.\n      if not feature_config.vocabulary_list:\n        continue\n      if not all(\n          isinstance(m, (list, tuple)) and len(m) == 2\n          for m in feature_config.monotonicity):\n        raise ValueError(\n            'Monotonicities should be a list of pairs (list/tuples): {}'.format(\n                feature_config.monotonicity))\n      indexed_monotonicities = []\n      index_map = {\n          category: index\n          for (index, category) in enumerate(feature_config.vocabulary_list)\n      }\n      if feature_config.default_value is not None:\n        index_map[feature_config.default_value] = feature_config.num_buckets - 1\n      for left, right in feature_config.monotonicity:\n        for category in [left, right]:\n          if category not in index_map:\n            raise ValueError(\n                'Category `{}` not found in vocabulary list for feature `{}`'\n                .format(category, feature_config.name))\n        indexed_monotonicities.append((index_map[left], index_map[right]))\n\n      feature_config.monotonicity = indexed_monotonicities\n\n\ndef set_random_lattice_ensemble(model_config, feature_names=None):\n  \"\"\"Sets random lattice ensemble in the given model_config.\n\n  Args:\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n    feature_names: A list of feature names. If not provided, feature names will\n      be extracted from the feature configs contained in the model_config.\n  \"\"\"\n  if not isinstance(model_config, configs.CalibratedLatticeEnsembleConfig):\n    raise ValueError(\n        'model_config must be a tfl.configs.CalibratedLatticeEnsembleConfig: {}'\n        .format(type(model_config)))\n  if model_config.lattices != 'random':\n    raise ValueError('model_config.lattices must be set to \\'random\\'.')\n  feature_names = _canonical_feature_names(model_config, feature_names)\n  # Start by using each feature once.\n  np.random.seed(model_config.random_seed)\n  model_config.lattices = [[] for _ in range(model_config.num_lattices)]\n  for feature_name in feature_names:\n    non_full_indices = [\n        i for (i, lattice) in enumerate(model_config.lattices)\n        if len(lattice) < model_config.lattice_rank\n    ]\n    model_config.lattices[np.random.choice(non_full_indices)].append(\n        feature_name)\n\n  # Fill up lattices avoiding repeated features.\n  for lattice in model_config.lattices:\n    feature_names_not_in_lattice = [\n        feature_name for feature_name in feature_names\n        if feature_name not in lattice\n    ]\n    remaining_size = model_config.lattice_rank - len(lattice)\n    lattice.extend(\n        np.random.choice(\n            feature_names_not_in_lattice, size=remaining_size, replace=False))\n\n\ndef _add_pair_to_ensemble(lattices, lattice_rank, i, j):\n  \"\"\"Adds pair (i, j) to the ensemble heuristically.\"\"\"\n  # First check if (i, j) pair is already present in a lattice.\n  for lattice in lattices:\n    if i in lattice and j in lattice:\n      return\n\n  # Try adding to a lattice that already has either i or j.\n  for lattice in lattices:\n    if len(lattice) < lattice_rank:\n      if i in lattice:\n        lattice.add(j)\n        return\n      if j in lattice:\n        lattice.add(i)\n        return\n\n  # Add both i and j to a lattice that has enough space left.\n  for lattice in lattices:\n    if 
len(lattice) < lattice_rank - 1:\n      lattice.add(i)\n      lattice.add(j)\n      return\n\n  # Create a new lattice with pair (i, j).\n  lattices.append(set([i, j]))\n\n\ndef _set_all_pairs_cover_lattices(prefitting_model_config, feature_names):\n  \"\"\"Sets prefitting lattice ensemble such that it covers all feature pairs.\"\"\"\n  # Pairs of co-occurrence that need to exist in the all-pairs cover.\n  to_cover = list(itertools.combinations(range(len(feature_names)), 2))\n  np.random.seed(prefitting_model_config.random_seed)\n  np.random.shuffle(to_cover)\n\n  lattices = []\n\n  for (i, j) in to_cover:\n    _add_pair_to_ensemble(lattices, prefitting_model_config.lattice_rank, i, j)\n\n  prefitting_model_config.lattices = [\n      [feature_names[i] for i in lattice] for lattice in lattices\n  ]\n\n\ndef construct_prefitting_model_config(model_config, feature_names=None):\n  \"\"\"Constructs a model config for a prefitting model for crystal extraction.\n\n  Args:\n    model_config: Model configuration object describing model architecture.\n      Should be a `tfl.configs.CalibratedLatticeEnsemble` instance.\n    feature_names: A list of feature names. If not provided, feature names will\n      be extracted from the feature configs contained in the model_config.\n\n  Returns:\n    A `tfl.configs.CalibratedLatticeEnsembleConfig` instance.\n  \"\"\"\n  if not isinstance(model_config, configs.CalibratedLatticeEnsembleConfig):\n    raise ValueError(\n        'model_config must be a tfl.configs.CalibratedLatticeEnsembleConfig: {}'\n        .format(type(model_config)))\n  if model_config.lattices != 'crystals':\n    raise ValueError('model_config.lattices must be set to \\'crystals\\'.')\n  feature_names = _canonical_feature_names(model_config, feature_names)\n\n  if len(feature_names) <= model_config.lattice_rank:\n    raise ValueError(\n        'model_config.lattice_rank must be less than the number of features '\n        'when using \\'crystals\\' algorithm. If you want to use all features in '\n        'every lattice, set model_config.lattices to \\'random\\'.')\n\n  # Make a copy of the model config provided and set all pairs covered.\n  prefitting_model_config = copy.deepcopy(model_config)\n  # Set parameterization of prefitting model to 'all_vertices' to extract\n  # crystals using normal lattice because we do not have laplacian/torsion\n  # regularizers for KFL. 
This should still extract good feature combinations.\n  prefitting_model_config.parameterization = 'all_vertices'\n  _set_all_pairs_cover_lattices(\n      prefitting_model_config=prefitting_model_config,\n      feature_names=feature_names)\n\n  # Trim the model for faster prefitting.\n  for feature_config in prefitting_model_config.feature_configs:\n    feature_config.lattice_size = 2\n    # Unimodality requires lattice_size > 2.\n    feature_config.unimodality = 0\n    # Disable 2d constraints to avoid potential constraint violations.\n    feature_config.dominates = None\n    feature_config.reflects_trust_in = None\n\n  # Return our properly constructed prefitting model config.\n  return prefitting_model_config\n\n\ndef _verify_prefitting_model(prefitting_model, feature_names):\n  \"\"\"Checks that prefitting_model has the proper input layer.\"\"\"\n  if isinstance(prefitting_model, keras.Model):\n    layer_names = [layer.name for layer in prefitting_model.layers]\n  elif hasattr(prefitting_model, 'get_variable_names'):  # estimator\n    layer_names = prefitting_model.get_variable_names()\n  else:\n    raise ValueError('Invalid model type for prefitting_model: {}'.format(\n        type(prefitting_model)))\n  for feature_name in feature_names:\n    if isinstance(prefitting_model, keras.Model):\n      input_layer_name = '{}_{}'.format(INPUT_LAYER_NAME, feature_name)\n      if input_layer_name not in layer_names:\n        raise ValueError(\n            'prefitting_model does not match prefitting_model_config. Make '\n            'sure that prefitting_model is the proper type and constructed '\n            'from the prefitting_model_config: {}'.format(\n                type(prefitting_model)))\n    else:\n      pwl_input_layer_name = '{}_{}/{}'.format(\n          CALIB_LAYER_NAME, feature_name,\n          pwl_calibration_layer.PWL_CALIBRATION_KERNEL_NAME)\n      cat_input_layer_name = '{}_{}/{}'.format(\n          CALIB_LAYER_NAME, feature_name,\n          categorical_calibration_layer.CATEGORICAL_CALIBRATION_KERNEL_NAME)\n      if (pwl_input_layer_name not in layer_names and\n          cat_input_layer_name not in layer_names):\n        raise ValueError(\n            'prefitting_model does not match prefitting_model_config. 
Make '\n            'sure that prefitting_model is the proper type and constructed '\n            'from the prefitting_model_config: {}'.format(\n                type(prefitting_model)))\n\n\ndef _get_lattice_weights(prefitting_model, lattice_index):\n  \"\"\"Gets the weights of the lattice at the specified index.\"\"\"\n  if isinstance(prefitting_model, keras.Model):\n    lattice_layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, lattice_index)\n    weights = keras.backend.get_value(\n        prefitting_model.get_layer(lattice_layer_name).weights[0])\n  else:\n    # We have already checked the types by this point, so if prefitting_model\n    # is not a keras Model it must be an Estimator.\n    lattice_kernel_variable_name = '{}_{}/{}'.format(\n        LATTICE_LAYER_NAME, lattice_index, lattice_layer.LATTICE_KERNEL_NAME)\n    weights = prefitting_model.get_variable_value(lattice_kernel_variable_name)\n  return weights\n\n\ndef _get_torsions_and_laplacians(prefitting_model_config, prefitting_model,\n                                 feature_names):\n  \"\"\"Returns average torsion and laplacian regularizers in prefitted model.\"\"\"\n  num_features = len(feature_names)\n  laplacians = [[] for _ in range(num_features)]\n  torsions = [[[] for _ in range(num_features)] for _ in range(num_features)]\n  for (lattice_index, lattice) in enumerate(prefitting_model_config.lattices):\n    # Get lattice weights and normalize them.\n    weights = _get_lattice_weights(prefitting_model, lattice_index)\n    weights -= np.min(weights)\n    weights /= np.max(weights)\n    weights = tf.constant(weights)\n\n    # Convert feature names in the lattice to their index in feature_names.\n    lattice = [feature_names.index(feature_name) for feature_name in lattice]\n    lattice_sizes = [2] * len(lattice)\n    # feature_* refers to feature index in feature_names.\n    # within_lattice_index_* is the index of input dimension of the lattice.\n    for within_lattice_index_0, feature_0 in enumerate(lattice):\n      l2 = [0] * len(lattice)\n      l2[within_lattice_index_0] = 1\n      laplacians[feature_0].append(\n          lattice_lib.laplacian_regularizer(\n              weights=weights, lattice_sizes=lattice_sizes, l2=l2))\n      for within_lattice_index_1, feature_1 in enumerate(lattice):\n        if within_lattice_index_1 > within_lattice_index_0:\n          l2 = [0] * len(lattice)\n          l2[within_lattice_index_0] = 1\n          l2[within_lattice_index_1] = 1\n          torsion = lattice_lib.torsion_regularizer(\n              weights=weights, lattice_sizes=lattice_sizes, l2=l2)\n          torsions[feature_0][feature_1].append(torsion)\n          torsions[feature_1][feature_0].append(torsion)\n\n  if not tf.executing_eagerly():\n    with tf.compat.v1.Session() as sess:\n      laplacians = sess.run(laplacians)\n      torsions = sess.run(torsions)\n\n  laplacians = [np.mean(v) for v in laplacians]\n  torsions = [[np.mean(v) if v else 0.0 for v in row] for row in torsions]\n  return torsions, laplacians\n\n\ndef _get_final_crystal_lattices(model_config, prefitting_model_config,\n                                prefitting_model, feature_names):\n  \"\"\"Extracts the lattice ensemble structure from the prefitting model.\"\"\"\n  torsions, laplacians = _get_torsions_and_laplacians(\n      prefitting_model_config=prefitting_model_config,\n      prefitting_model=prefitting_model,\n      feature_names=feature_names)\n\n  # Calculate features' importance_score = lambda * laplacians + torsion.\n  # Used to allocate slots to 
useful features with more non-linear interactions.\n  num_features = len(feature_names)\n  importance_scores = np.array(laplacians) * _LAPLACIAN_WEIGHT_IN_IMPORTANCE\n  for feature_0, feature_1 in itertools.combinations(range(num_features), 2):\n    importance_scores[feature_0] += torsions[feature_0][feature_1]\n    importance_scores[feature_1] += torsions[feature_0][feature_1]\n\n  # Each feature is used at least once, and the remaining slots are distributed\n  # proportional to the importance_scores.\n  features_uses = [1] * num_features\n  total_feature_use = model_config.num_lattices * model_config.lattice_rank\n  remaining_uses = total_feature_use - num_features\n  remaining_scores = np.sum(importance_scores)\n  for feature in np.argsort(-importance_scores):\n    added_uses = int(\n        round(remaining_uses * importance_scores[feature] / remaining_scores))\n    # Each feature cannot be used more than once in a finalized lattice.\n    added_uses = min(added_uses, model_config.num_lattices - 1)\n    features_uses[feature] += added_uses\n    remaining_uses -= added_uses\n    remaining_scores -= importance_scores[feature]\n  assert np.sum(features_uses) == total_feature_use\n\n  # Add features to add list in round-robin order.\n  add_list = []\n  for use in range(1, max(features_uses) + 1):\n    for feature_index, feature_use in enumerate(features_uses):\n      if use <= feature_use:\n        add_list.append(feature_index)\n  assert len(add_list) == total_feature_use\n\n  # Setup initial lattices that will be optimized by swapping later.\n  lattices = [[] for _ in range(model_config.num_lattices)]\n  cooccurrence_counts = [[0] * num_features for _ in range(num_features)]\n  for feature_to_be_added in add_list:\n    # List of pairs of (addition_score, candidate_lattice_to_add_to).\n    score_candidates_pairs = []\n    for candidate_lattice_to_add_to in range(model_config.num_lattices):\n      # addition_score indicates the priority of an addition.\n      if len(\n          lattices[candidate_lattice_to_add_to]) >= model_config.lattice_rank:\n        # going out of bound on the lattice\n        addition_score = -2.0\n      elif feature_to_be_added in lattices[candidate_lattice_to_add_to]:\n        # repeates (fixed repeats later by swapping)\n        addition_score = -1.0\n      elif not lattices[candidate_lattice_to_add_to]:\n        # adding a new lattice roughly has an \"average\" lattice score\n        addition_score = np.mean(torsions) * model_config.lattice_rank**2 / 2\n      else:\n        # all other cases: change in total discounted torsion after addition.\n        addition_score = 0.0\n        for other_feature in lattices[candidate_lattice_to_add_to]:\n          addition_score += (\n              torsions[feature_to_be_added][other_feature] *\n              _REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE\n              **(cooccurrence_counts[feature_to_be_added][other_feature]))\n\n      score_candidates_pairs.append(\n          (addition_score, candidate_lattice_to_add_to))\n\n    # Use the highest scoring addition.\n    score_candidates_pairs.sort(reverse=True)\n    best_candidate_lattice_to_add_to = score_candidates_pairs[0][1]\n    for other_feature in lattices[best_candidate_lattice_to_add_to]:\n      cooccurrence_counts[feature_to_be_added][other_feature] += 1\n      cooccurrence_counts[other_feature][feature_to_be_added] += 1\n    lattices[best_candidate_lattice_to_add_to].append(feature_to_be_added)\n\n  # Apply swapping operations to increase within-lattice torsion.\n  
changed = True\n  iteration = 0\n  while changed:\n    if iteration > _MAX_CRYSTALS_SWAPS:\n      logging.info('Crystals algorithm did not fully converge.')\n      break\n    changed = False\n    iteration += 1\n    for lattice_0, lattice_1 in itertools.combinations(lattices, 2):\n      # For every pair of lattices: lattice_0, lattice_1\n      for index_0, index_1 in itertools.product(\n          range(len(lattice_0)), range(len(lattice_1))):\n        # Consider swapping lattice_0[index_0] with lattice_1[index_1]\n        rest_lattice_0 = list(lattice_0)\n        rest_lattice_1 = list(lattice_1)\n        feature_0 = rest_lattice_0.pop(index_0)\n        feature_1 = rest_lattice_1.pop(index_1)\n        if feature_0 == feature_1:\n          continue\n\n        # Calculate the change in the overall discounted sum of torsion terms.\n        added_cooccurrence = set(\n            [tuple(sorted((feature_1, other))) for other in rest_lattice_0] +\n            [tuple(sorted((feature_0, other))) for other in rest_lattice_1])\n        removed_cooccurrence = set(\n            [tuple(sorted((feature_0, other))) for other in rest_lattice_0] +\n            [tuple(sorted((feature_1, other))) for other in rest_lattice_1])\n        wash = added_cooccurrence.intersection(removed_cooccurrence)\n        added_cooccurrence = added_cooccurrence.difference(wash)\n        removed_cooccurrence = removed_cooccurrence.difference(wash)\n        swap_diff_torsion = (\n            sum(torsions[i][j] * _REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE**\n                cooccurrence_counts[i][j] for (i, j) in added_cooccurrence) -\n            sum(torsions[i][j] * _REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE**\n                (cooccurrence_counts[i][j] - 1)\n                for (i, j) in removed_cooccurrence))\n\n        # Swap if a feature is repeated or if the score change is positive.\n        if (feature_0 not in lattice_1 and feature_1 not in lattice_0 and\n            (lattice_0.count(feature_0) > 1 or lattice_1.count(feature_1) > 1 or\n             swap_diff_torsion > 0)):\n          for (i, j) in added_cooccurrence:\n            cooccurrence_counts[i][j] += 1\n            cooccurrence_counts[j][i] += 1\n          for (i, j) in removed_cooccurrence:\n            cooccurrence_counts[i][j] -= 1\n            cooccurrence_counts[j][i] -= 1\n          lattice_0[index_0], lattice_1[index_1] = (lattice_1[index_1],\n                                                    lattice_0[index_0])\n          changed = True\n  # Return the extracted lattice structure.\n  return lattices\n\n\ndef set_crystals_lattice_ensemble(model_config,\n                                  prefitting_model_config,\n                                  prefitting_model,\n                                  feature_names=None):\n  \"\"\"Extracts crystals from a prefitting model and finalizes model_config.\n\n  Args:\n    model_config: Model configuration object describing model architecture.\n      Should be a `tfl.configs.CalibratedLatticeEnsemble` instance.\n    prefitting_model_config: Model configuration object describing prefitting\n      model architecture. Should be a `tfl.configs.CalibratedLatticeEnsemble`\n      insance constructed using\n      `tfl.premade_lib.construct_prefitting_model_config`.\n    prefitting_model: A trained `tfl.premade.CalibratedLatticeEnsemble`,\n      `tfl.estimators.CannedEstimator`, `tfl.estimators.CannedClassifier`, or\n      `tfl.estiamtors.CannedRegressor` instance.\n    feature_names: A list of feature names. 
If not provided, feature names will\n      be extracted from the feature configs contained in the model_config.\n  \"\"\"\n  # Error checking parameter types.\n  if not isinstance(model_config, configs.CalibratedLatticeEnsembleConfig):\n    raise ValueError(\n        'model_config must be a tfl.configs.CalibratedLatticeEnsembleConfig: {}'\n        .format(type(model_config)))\n  if not isinstance(prefitting_model_config,\n                    configs.CalibratedLatticeEnsembleConfig):\n    raise ValueError('prefitting_model_config must be a '\n                     'tfl.configs.CalibratedLatticeEnsembleConfig: {}'.format(\n                         type(model_config)))\n  if model_config.lattices != 'crystals':\n    raise ValueError('model_config.lattices must be set to \\'crystals\\'.')\n  # Note that we cannot check the type of the prefitting model without importing\n  # premade/estimators, which would cause a cyclic dependency. However, we can\n  # check that the model is a keras.Model or tf.Estimator instance that has\n  # the proper input layers matching prefitting_model_config feature_configs.\n  # Beyond that, a prefitting_model with proper input layer names that is not of\n  # the proper type will have undefined behavior.\n  # To perform this check, we must first extract feature names if they are not\n  # provided, which we need for later steps anyway.\n  feature_names = _canonical_feature_names(model_config, feature_names)\n  _verify_prefitting_model(prefitting_model, feature_names)\n\n  # Now we can extract the crystals and finalize model_config.\n  lattices = _get_final_crystal_lattices(\n      model_config=model_config,\n      prefitting_model_config=prefitting_model_config,\n      prefitting_model=prefitting_model,\n      feature_names=feature_names)\n  model_config.lattices = [[\n      feature_names[features_index] for features_index in lattice\n  ] for lattice in lattices]\n\n\ndef _weighted_quantile(sorted_values, quantiles, weights):\n  \"\"\"Calculates weighted quantiles of the given sorted and unique values.\"\"\"\n  if len(sorted_values) < len(quantiles):\n    raise ValueError(\n        'Not enough unique values ({}) to calculate {} quantiles.'.format(\n            len(sorted_values), len(quantiles)))\n  # Weighted quantiles of the observed (sorted) values.\n  # Weights are spread equaly before and after the observed values.\n  weighted_quantiles = (np.cumsum(weights) - 0.5 * weights) / np.sum(weights)\n\n  # Use linear interpolation to find index of the quantile values.\n  index_values = np.arange(len(sorted_values))\n  quantiles_idx = np.interp(x=quantiles, xp=weighted_quantiles, fp=index_values)\n  quantiles_idx = np.rint(quantiles_idx).astype(int)\n\n  # Replace repeated quantile values with neighbouring values.\n  unique_idx, first_use = np.unique(quantiles_idx, return_index=True)\n  used_idx = set(unique_idx)\n  num_values = len(sorted_values)\n  for i in range(len(quantiles_idx)):\n    if i not in first_use:\n      # Since this is not the first use of a (repeated) quantile value, we will\n      # need to find an unused neighbouring value.\n      for delta, direction in itertools.product(range(1, num_values), [-1, 1]):\n        candidate_idx = quantiles_idx[i] + direction * delta\n        if (candidate_idx >= 0 and candidate_idx < num_values and\n            candidate_idx not in used_idx):\n          used_idx.add(candidate_idx)\n          quantiles_idx[i] = candidate_idx\n          break\n  quantiles_idx = np.sort(quantiles_idx)\n\n  return 
sorted_values[quantiles_idx]\n\n\ndef compute_keypoints(values,\n                      num_keypoints,\n                      keypoints='quantiles',\n                      clip_min=None,\n                      clip_max=None,\n                      default_value=None,\n                      weights=None,\n                      weight_reduction='mean',\n                      feature_name=''):\n  \"\"\"Calculates keypoints for the given set of values.\n\n  Args:\n    values: Values to use for quantile calculation.\n    num_keypoints: Number of keypoints to compute.\n    keypoints: String `'quantiles'` or `'uniform'`.\n    clip_min: Input values are lower clipped by this value.\n    clip_max: Input values are upper clipped by this value.\n    default_value: If provided, its occurrences will be removed from values.\n    weights: Weights to be used for quantile calculation.\n    weight_reduction: Reduction applied to weights for repeated values. Must be\n      either 'mean' or 'sum'.\n    feature_name: Name to use for error logs.\n\n  Returns:\n    A list of keypoints of `num_keypoints` length.\n  \"\"\"\n  # Remove default values before calculating stats.\n  non_default_idx = values != default_value\n  values = values[non_default_idx]\n  if weights is not None:\n    weights = weights[non_default_idx]\n\n  # Clip min and max if requested. Note that we add clip bounds to the values\n  # so that the first and last keypoints are set to those values.\n  if clip_min is not None:\n    values = np.maximum(values, clip_min)\n    values = np.append(values, clip_min)\n    if weights is not None:\n      weights = np.append(weights, 0)\n  if clip_max is not None:\n    values = np.minimum(values, clip_max)\n    values = np.append(values, clip_max)\n    if weights is not None:\n      weights = np.append(weights, 0)\n\n  # We do not allow nans in the data, even as default_value.\n  if np.isnan(values).any():\n    raise ValueError(\n        'NaN values were observed for numeric feature `{}`. '\n        'Consider replacing the values in transform or input_fn.'.format(\n            feature_name))\n\n  # Remove duplicates and sort values before calculating stats.\n  # This is empirically useful as it uses keypoints more efficiently.\n  if weights is None:\n    sorted_values = np.unique(values)\n  else:\n    # First sort the values and reorder weights.\n    idx = np.argsort(values)\n    values = values[idx]\n    weights = weights[idx]\n\n    # Set the weight of each unique element to be the sum or average of the\n    # weights of repeated instances. Using 'mean' reduction results in parity\n    # between unweighted calculation and having equal weights for all values.\n    sorted_values, idx, counts = np.unique(\n        values, return_index=True, return_counts=True)\n    weights = np.add.reduceat(weights, idx)\n    if weight_reduction == 'mean':\n      weights = weights / counts\n    elif weight_reduction != 'sum':\n      raise ValueError('Invalid weight reduction: {}'.format(weight_reduction))\n\n  if keypoints == 'quantiles':\n    if sorted_values.size < num_keypoints:\n      logging.info(\n          'Not enough unique values observed for feature `%s` to '\n          'construct %d keypoints for pwl calibration. 
Using %d unique '\n          'values as keypoints.', feature_name, num_keypoints,\n          sorted_values.size)\n      return sorted_values.astype(float)\n\n    quantiles = np.linspace(0., 1., num_keypoints)\n    if weights is not None:\n      return _weighted_quantile(\n          sorted_values=sorted_values, quantiles=quantiles,\n          weights=weights).astype(float)\n    else:\n      return np.quantile(\n          sorted_values, quantiles, interpolation='nearest').astype(float)\n\n  elif keypoints == 'uniform':\n    return np.linspace(sorted_values[0], sorted_values[-1], num_keypoints)\n  else:\n    raise ValueError('Invalid keypoint generation mode: {}'.format(keypoints))\n\n\ndef _feature_config_by_name(feature_configs, feature_name, add_if_missing):\n  \"\"\"Returns feature_config with the given name.\"\"\"\n  for feature_config in feature_configs:\n    if feature_config.name == feature_name:\n      return feature_config\n  # Use the default FeatureConfig if not present.\n  feature_config = configs.FeatureConfig(feature_name)\n  if add_if_missing:\n    feature_configs.append(feature_config)\n  return feature_config\n\n\ndef compute_feature_keypoints(feature_configs,\n                              features,\n                              weights=None,\n                              weight_reduction='mean'):\n  \"\"\"Computes feature keypoints with the data provided in the `features` dict.\"\"\"\n  # Calculate feature keypoints.\n  feature_keypoints = {}\n  for feature_name, values in six.iteritems(features):\n    feature_config = _feature_config_by_name(\n        feature_configs=feature_configs,\n        feature_name=feature_name,\n        add_if_missing=False)\n\n    if feature_config.num_buckets:\n      # Skip categorical features.\n      continue\n    if isinstance(feature_config.pwl_calibration_input_keypoints, str):\n      feature_keypoints[feature_name] = compute_keypoints(\n          values,\n          num_keypoints=feature_config.pwl_calibration_num_keypoints,\n          keypoints=feature_config.pwl_calibration_input_keypoints,\n          clip_min=feature_config.pwl_calibration_clip_min,\n          clip_max=feature_config.pwl_calibration_clip_max,\n          default_value=feature_config.default_value,\n          weights=weights,\n          weight_reduction=weight_reduction,\n          feature_name=feature_name,\n      )\n    else:\n      # User-specified keypoint values.\n      feature_keypoints[\n          feature_name] = feature_config.pwl_calibration_input_keypoints\n  return feature_keypoints\n\n\ndef set_feature_keypoints(feature_configs, feature_keypoints,\n                          add_missing_feature_configs):\n  \"\"\"Updates the feature configs with provided keypoints.\"\"\"\n  for feature_name, keypoints in six.iteritems(feature_keypoints):\n    feature_config = _feature_config_by_name(\n        feature_configs=feature_configs,\n        feature_name=feature_name,\n        add_if_missing=add_missing_feature_configs)\n    feature_config.pwl_calibration_input_keypoints = keypoints\n\n\ndef compute_label_keypoints(model_config,\n                            labels,\n                            logits_output,\n                            weights=None,\n                            weight_reduction='mean'):\n  \"\"\"Computes label keypoints with the data provided in the `labels` array.\"\"\"\n  if not np.issubdtype(labels[0], np.number):\n    # Default feature_values to [0, ... 
n_class-1] for string labels.\n    labels = np.arange(len(set(labels)))\n    weights = None\n\n  if isinstance(model_config.output_initialization, str):\n    # If model is expected to produce logits, initialize linearly in the\n    # range [-2, 2], ignoring the label distribution.\n    if logits_output:\n      return np.linspace(-2, 2, model_config.output_calibration_num_keypoints)\n\n    return compute_keypoints(\n        labels,\n        num_keypoints=model_config.output_calibration_num_keypoints,\n        keypoints=model_config.output_initialization,\n        clip_min=model_config.output_min,\n        clip_max=model_config.output_max,\n        weights=weights,\n        weight_reduction=weight_reduction,\n        feature_name='label',\n    )\n  else:\n    # User-specified keypoint values.\n    return model_config.output_initialization\n\n\ndef set_label_keypoints(model_config, label_keypoints):\n  \"\"\"Updates the label keypoints in the `model_config`.\"\"\"\n  model_config.output_initialization = label_keypoints\n\n\ndef _verify_ensemble_config(model_config):\n  \"\"\"Verifies that an ensemble model and feature configs are properly specified.\n\n  Args:\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n\n  Raises:\n    ValueError: If `model_config.lattices` is set to 'rtl_layer' and\n      `model_config.num_lattices` is not specified.\n    ValueError: If `model_config.num_lattices < 2`.\n    ValueError: If `model_config.lattices` is set to 'rtl_layer' and\n      `lattice_size` is not the same for all features.\n    ValueError: If `model_config.lattices` is set to 'rtl_layer' and\n      there are features with unimodality constraints.\n    ValueError: If `model_config.lattices` is set to 'rtl_layer' and\n      there are features with trust constraints.\n    ValueError: If `model_config.lattices` is set to 'rtl_layer' and\n      there are features with dominance constraints.\n    ValueError: If `model_config.lattices` is set to 'rtl_layer' and\n      there are per-feature lattice regularizers.\n    ValueError: If `model_config.lattices` is not iterable or constaints\n      non-string values.\n    ValueError: If `model_config.lattices` is not set to 'rtl_layer' or a fully\n      specified list of lists of feature names.\n  \"\"\"\n  if model_config.lattices == 'rtl_layer':\n    # RTL must have num_lattices specified and >= 2.\n    if model_config.num_lattices is None:\n      raise ValueError('model_config.num_lattices must be specified when '\n                       'model_config.lattices is set to \\'rtl_layer\\'.')\n    if model_config.num_lattices < 2:\n      raise ValueError(\n          'CalibratedLatticeEnsemble must have >= 2 lattices. 
For single '\n          'lattice models, use CalibratedLattice instead.')\n    # Check that all lattices sizes for all features are the same.\n    if any(feature_config.lattice_size !=\n           model_config.feature_configs[0].lattice_size\n           for feature_config in model_config.feature_configs):\n      raise ValueError('RTL Layer must have the same lattice size for all '\n                       'features.')\n    # Check that there are only monotonicity and bound constraints.\n    if any(\n        feature_config.unimodality != 'none' and feature_config.unimodality != 0\n        for feature_config in model_config.feature_configs):\n      raise ValueError(\n          'RTL Layer does not currently support unimodality constraints.')\n    if any(feature_config.reflects_trust_in is not None\n           for feature_config in model_config.feature_configs):\n      raise ValueError(\n          'RTL Layer does not currently support trust constraints.')\n    if any(feature_config.dominates is not None\n           for feature_config in model_config.feature_configs):\n      raise ValueError(\n          'RTL Layer does not currently support dominance constraints.')\n    # Check that there are no per-feature lattice regularizers.\n    for feature_config in model_config.feature_configs:\n      for regularizer_config in feature_config.regularizer_configs or []:\n        if not regularizer_config.name.startswith(\n            _INPUT_CALIB_REGULARIZER_PREFIX):\n          raise ValueError(\n              'RTL Layer does not currently support per-feature lattice '\n              'regularizers.')\n  elif isinstance(model_config.lattices, list):\n    # Make sure there are more than one lattice. If not, tell user to use\n    # CalibratedLattice instead.\n    if len(model_config.lattices) < 2:\n      raise ValueError(\n          'CalibratedLatticeEnsemble must have >= 2 lattices. For single '\n          'lattice models, use CalibratedLattice instead.')\n    for lattice in model_config.lattices:\n      if (not np.iterable(lattice) or\n          any(not isinstance(x, str) for x in lattice)):\n        raise ValueError(\n            'Lattices are not fully specified for ensemble config.')\n  else:\n    raise ValueError(\n        'Lattices are not fully specified for ensemble config. 
Lattices must '\n        'be set to \\'rtl_layer\\' or be fully specified as a list of lists of '\n        'feature names.')\n\n\ndef _verify_kronecker_factored_config(model_config):\n  \"\"\"Verifies that a kronecker_factored model_config is properly specified.\n\n  Args:\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n\n  Raises:\n    ValueError: If there are lattice regularizers.\n    ValueError: If there are per-feature lattice regularizers.\n    ValueError: If there are unimodality constraints.\n    ValueError: If there are trust constraints.\n    ValueError: If there are dominance constraints.\n  \"\"\"\n  for regularizer_config in model_config.regularizer_configs or []:\n    if not regularizer_config.name.startswith(_INPUT_CALIB_REGULARIZER_PREFIX):\n      raise ValueError(\n          'KroneckerFactoredLattice layer does not currently support '\n          'lattice regularizers.')\n  for feature_config in model_config.feature_configs:\n    for regularizer_config in feature_config.regularizer_configs or []:\n      if not regularizer_config.name.startswith(\n          _INPUT_CALIB_REGULARIZER_PREFIX):\n        raise ValueError(\n            'KroneckerFactoredLattice layer does not currently support '\n            'per-feature lattice regularizers.')\n  # Check that all lattices sizes for all features are the same.\n  if any(feature_config.lattice_size !=\n         model_config.feature_configs[0].lattice_size\n         for feature_config in model_config.feature_configs):\n    raise ValueError('KroneckerFactoredLattice layer must have the same '\n                     'lattice size for all features.')\n  # Check that there are only monotonicity and bound constraints.\n  if any(\n      feature_config.unimodality != 'none' and feature_config.unimodality != 0\n      for feature_config in model_config.feature_configs):\n    raise ValueError(\n        'KroneckerFactoredLattice layer does not currently support unimodality '\n        'constraints.')\n  if any(feature_config.reflects_trust_in is not None\n         for feature_config in model_config.feature_configs):\n    raise ValueError(\n        'KroneckerFactoredLattice layer does not currently support trust '\n        'constraints.')\n  if any(feature_config.dominates is not None\n         for feature_config in model_config.feature_configs):\n    raise ValueError(\n        'KroneckerFactoredLattice layer does not currently support dominance '\n        'constraints.')\n\n\ndef _verify_aggregate_function_config(model_config):\n  \"\"\"Verifies that an aggregate function model_config is properly specified.\n\n  Args:\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n\n  Raises:\n    ValueError: If `middle_dimension < 1`.\n    ValueError: If `model_config.middle_monotonicity` is not None and\n      `model_config.middle_calibration` is not True.\n  \"\"\"\n  if model_config.middle_dimension < 1:\n    raise ValueError('Middle dimension must be at least 1: {}'.format(\n        model_config.middle_dimension))\n  if (model_config.middle_monotonicity is not None and\n      not model_config.middle_calibration):\n    raise ValueError(\n        'middle_calibration must be true when middle_monotonicity is '\n        'specified.')\n\n\ndef _verify_feature_config(feature_config):\n  \"\"\"Verifies that feature_config is properly specified.\n\n  Args:\n    feature_config: Feature 
configuration object describing an input feature to\n      a model. Should be an instance of `tfl.configs.FeatureConfig`.\n\n  Raises:\n    ValueError: If `feature_config.pwl_calibration_input_keypoints` is not\n      iterable or contains non-{int/float} values for a numerical feature.\n    ValueError: If `feature_config.monotonicity` is not an iterable for a\n      categorical feature.\n    ValueError: If any element in `feature_config.monotonicity` is not an\n      iterable for a categorical feature.\n    ValueError: If any value in any element in `feature_config.monotonicity` is\n      not an int for a categorical feature.\n    ValueError: If any value in any element in `feature_config.monotonicity` is\n      not in the range `[0, feature_config.num_buckets]` for a categorical\n      feature.\n  \"\"\"\n  if not feature_config.num_buckets:\n    # Validate PWL Calibration configuration.\n    if (not np.iterable(feature_config.pwl_calibration_input_keypoints) or\n        any(not isinstance(x, (int, float))\n            for x in feature_config.pwl_calibration_input_keypoints)):\n      raise ValueError('Input keypoints are invalid for feature {}: {}'.format(\n          feature_config.name, feature_config.pwl_calibration_input_keypoints))\n  elif feature_config.monotonicity and feature_config.monotonicity != 'none':\n    # Validate Categorical Calibration configuration.\n    if not np.iterable(feature_config.monotonicity):\n      raise ValueError('Monotonicity is not a list for feature {}: {}'.format(\n          feature_config.name, feature_config.monotonicity))\n    for i, t in enumerate(feature_config.monotonicity):\n      if not np.iterable(t):\n        raise ValueError(\n            'Element {} is not a list/tuple for feature {} monotonicty: {}'\n            .format(i, feature_config.name, t))\n      for j, val in enumerate(t):\n        if not isinstance(val, int):\n          raise ValueError(\n              'Element {} for list/tuple {} for feature {} monotonicity is '\n              'not an index: {}'.format(j, i, feature_config.name, val))\n        if val < 0 or val >= feature_config.num_buckets:\n          raise ValueError(\n              'Element {} for list/tuple {} for feature {} monotonicity is '\n              'an invalid index not in range [0, num_buckets - 1]: {}'.format(\n                  j, i, feature_config.name, val))\n\n\ndef verify_config(model_config):\n  \"\"\"Verifies that the model_config and feature_configs are properly specified.\n\n  Args:\n    model_config: Model configuration object describing model architecture.\n      Should be one of the model configs in `tfl.configs`.\n\n  Raises:\n    ValueError: If `model_config.feature_configs` is None.\n    ValueError: If `model_config.output_initialization` is not iterable or\n      contains non-{int/float} values.\n\n  \"\"\"\n  if model_config.feature_configs is None:\n    raise ValueError('Feature configs must be fully specified.')\n  if isinstance(model_config, configs.CalibratedLatticeEnsembleConfig):\n    _verify_ensemble_config(model_config)\n  if ((isinstance(model_config, configs.CalibratedLatticeEnsembleConfig) or\n       isinstance(model_config, configs.CalibratedLatticeConfig)) and\n      model_config.parameterization == 'kronecker_factored'):\n    _verify_kronecker_factored_config(model_config)\n  if isinstance(model_config, configs.AggregateFunctionConfig):\n    _verify_aggregate_function_config(model_config)\n  for feature_config in model_config.feature_configs:\n    
_verify_feature_config(feature_config)\n  if (not np.iterable(model_config.output_initialization) or\n      any(not isinstance(x, (int, float))\n          for x in model_config.output_initialization)):\n    raise ValueError('Output initialization is invalid: {}'.format(\n        model_config.output_initialization))\n"
  },
  {
    "path": "tensorflow_lattice/python/premade_test.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for Tensorflow Lattice premade.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport copy\nimport json\n\nimport tempfile\nfrom absl import logging\nfrom absl.testing import parameterized\nimport numpy as np\nimport pandas as pd\nimport tensorflow as tf\nfrom tensorflow_lattice.python import configs\nfrom tensorflow_lattice.python import premade\nfrom tensorflow_lattice.python import premade_lib\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, 'version', None)\nif version_fn and version_fn().startswith('3.'):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\nfake_data = {\n    'train_xs': [np.array([1]), np.array([3]), np.array([0])],\n    'train_ys': np.array([1]),\n    'eval_xs': [np.array([2]), np.array([30]), np.array([-3])]\n}\n\nunspecified_feature_configs = [\n    configs.FeatureConfig(\n        name='numerical_1',\n        lattice_size=2,\n        pwl_calibration_input_keypoints=np.linspace(0.0, 1.0, num=10),\n    ),\n    configs.FeatureConfig(\n        name='numerical_2',\n        lattice_size=2,\n        pwl_calibration_input_keypoints=np.linspace(0.0, 1.0, num=10),\n    ),\n    configs.FeatureConfig(\n        name='categorical',\n        lattice_size=2,\n        num_buckets=2,\n        monotonicity=[('0.0', '1.0')],\n        vocabulary_list=['0.0', '1.0'],\n    ),\n]\n\nspecified_feature_configs = [\n    configs.FeatureConfig(\n        name='numerical_1',\n        lattice_size=2,\n        pwl_calibration_input_keypoints=np.linspace(0.0, 1.0, num=10),\n    ),\n    configs.FeatureConfig(\n        name='numerical_2',\n        lattice_size=2,\n        pwl_calibration_input_keypoints=np.linspace(0.0, 1.0, num=10),\n    ),\n    configs.FeatureConfig(\n        name='categorical',\n        lattice_size=2,\n        num_buckets=2,\n        monotonicity=[(0, 1)],\n    ),\n]\n\nfeature_configs = [\n    configs.FeatureConfig(\n        name='numerical_1',\n        lattice_size=2,\n        pwl_calibration_input_keypoints=np.linspace(0.0, 1.0, num=10),\n    ),\n    configs.FeatureConfig(\n        name='numerical_2',\n        lattice_size=2,\n        pwl_calibration_input_keypoints=np.linspace(0.0, 1.0, num=10),\n    ),\n    configs.FeatureConfig(\n        name='categorical',\n        lattice_size=2,\n        num_buckets=2,\n        monotonicity=[(0, 1)],\n    ),\n]\n\n\nclass PremadeTest(parameterized.TestCase, tf.test.TestCase):\n  \"\"\"Tests for TFL premade.\"\"\"\n\n  def setUp(self):\n    super(PremadeTest, self).setUp()\n    keras.utils.set_random_seed(42)\n\n    # UCI Statlog (Heart) dataset.\n    heart_csv_file = keras.utils.get_file(\n        'heart.csv',\n        'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv',\n    )\n    heart_df = pd.read_csv(heart_csv_file)\n    thal_vocab_list = ['normal', 'fixed', 'reversible']\n    
heart_df['thal'] = heart_df['thal'].map(\n        {v: i for i, v in enumerate(thal_vocab_list)})\n    heart_df = heart_df.astype(float)\n\n    heart_train_size = int(len(heart_df) * 0.8)\n    heart_train_dict = dict(heart_df[:heart_train_size])\n    heart_test_dict = dict(heart_df[heart_train_size:])\n\n    # Features:\n    # - age\n    # - sex\n    # - cp        chest pain type (4 values)\n    # - trestbps  resting blood pressure\n    # - chol      serum cholestoral in mg/dl\n    # - fbs       fasting blood sugar > 120 mg/dl\n    # - restecg   resting electrocardiographic results (values 0,1,2)\n    # - thalach   maximum heart rate achieved\n    # - exang     exercise induced angina\n    # - oldpeak   ST depression induced by exercise relative to rest\n    # - slope     the slope of the peak exercise ST segment\n    # - ca        number of major vessels (0-3) colored by flourosopy\n    # - thal      normal; fixed defect; reversable defect\n    self.heart_feature_configs = [\n        configs.FeatureConfig(\n            name='age',\n            lattice_size=3,\n            monotonicity='increasing',\n            # We must set the keypoints manually.\n            pwl_calibration_num_keypoints=5,\n            pwl_calibration_input_keypoints='quantiles',\n            pwl_calibration_clip_max=100.,\n            # Per feature regularization.\n            regularizer_configs=[\n                configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),\n            ],\n        ),\n        configs.FeatureConfig(\n            name='sex',\n            num_buckets=2,\n        ),\n        configs.FeatureConfig(\n            name='cp',\n            monotonicity='increasing',\n            # Keypoints that are uniformly spaced.\n            pwl_calibration_num_keypoints=4,\n            pwl_calibration_input_keypoints='uniform',\n        ),\n        configs.FeatureConfig(\n            name='chol',\n            monotonicity='increasing',\n            # Explicit input keypoints initialization.\n            pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\n            # Calibration can be forced to span the full output range\n            # by clamping.\n            pwl_calibration_clamp_min=True,\n            pwl_calibration_clamp_max=True,\n            # Per feature regularization.\n            regularizer_configs=[\n                configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n            ],\n        ),\n        configs.FeatureConfig(\n            name='fbs',\n            # Partial monotonicity: output(0) <= output(1)\n            monotonicity=[(0, 1)],\n            num_buckets=2,\n        ),\n        configs.FeatureConfig(\n            name='trestbps',\n            monotonicity='decreasing',\n            pwl_calibration_num_keypoints=5,\n            pwl_calibration_input_keypoints='quantiles',\n        ),\n        configs.FeatureConfig(\n            name='thalach',\n            monotonicity='decreasing',\n            pwl_calibration_num_keypoints=5,\n            pwl_calibration_input_keypoints='quantiles',\n        ),\n        configs.FeatureConfig(\n            name='restecg',\n            # Partial monotonicity:\n            # output(0) <= output(1), output(0) <= output(2)\n            monotonicity=[(0, 1), (0, 2)],\n            num_buckets=3,\n        ),\n        configs.FeatureConfig(\n            name='exang',\n            # Partial monotonicity: output(0) <= output(1)\n            monotonicity=[(0, 1)],\n            num_buckets=2,\n        ),\n        
configs.FeatureConfig(\n            name='oldpeak',\n            monotonicity='increasing',\n            pwl_calibration_num_keypoints=5,\n            pwl_calibration_input_keypoints='quantiles',\n        ),\n        configs.FeatureConfig(\n            name='slope',\n            # Partial monotonicity:\n            # output(0) <= output(1), output(1) <= output(2)\n            monotonicity=[(0, 1), (1, 2)],\n            num_buckets=3,\n        ),\n        configs.FeatureConfig(\n            name='ca',\n            monotonicity='increasing',\n            pwl_calibration_num_keypoints=4,\n            pwl_calibration_input_keypoints='quantiles',\n        ),\n        configs.FeatureConfig(\n            name='thal',\n            # Partial monotonicity:\n            # output(normal) <= output(fixed)\n            # output(normal) <= output(reversible)\n            monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],\n            num_buckets=3,\n            # We must specify the vocabulary list in order to later set the\n            # monotonicities since we used names and not indices.\n            vocabulary_list=thal_vocab_list,\n        ),\n    ]\n    premade_lib.set_categorical_monotonicities(self.heart_feature_configs)\n\n    # This ordering of input features should match the feature configs.\n    feature_names = [\n        feature_config.name for feature_config in self.heart_feature_configs\n    ]\n    label_name = 'target'\n    self.heart_train_x = [\n        heart_train_dict[feature_name] for feature_name in feature_names\n    ]\n    self.heart_test_x = [\n        heart_test_dict[feature_name] for feature_name in feature_names\n    ]\n    self.heart_train_y = heart_train_dict[label_name]\n    self.heart_test_y = heart_test_dict[label_name]\n\n    # Construct feature map for keypoint calculation.\n    feature_keypoints = premade_lib.compute_feature_keypoints(\n        feature_configs=self.heart_feature_configs, features=heart_train_dict)\n    premade_lib.set_feature_keypoints(\n        feature_configs=self.heart_feature_configs,\n        feature_keypoints=feature_keypoints,\n        add_missing_feature_configs=False)\n\n  def _ResetAllBackends(self):\n    keras.backend.clear_session()\n    tf.compat.v1.reset_default_graph()\n\n  class Encoder(json.JSONEncoder):\n\n    def default(self, o):\n      if isinstance(o, np.int32):\n        return int(o)\n      if isinstance(o, np.ndarray):\n        return o.tolist()\n      return json.JSONEncoder.default(self, o)\n\n  def testSetRandomLattices(self):\n    random_model_config = configs.CalibratedLatticeEnsembleConfig(\n        feature_configs=copy.deepcopy(unspecified_feature_configs),\n        lattices='random',\n        num_lattices=3,\n        lattice_rank=2,\n        separate_calibrators=True,\n        output_initialization=[-1.0, 1.0])\n\n    premade_lib.set_random_lattice_ensemble(random_model_config)\n    self.assertLen(random_model_config.lattices, 3)\n    self.assertListEqual(\n        [2, 2, 2], [len(lattice) for lattice in random_model_config.lattices])\n\n    specified_model_config = configs.CalibratedLatticeEnsembleConfig(\n        feature_configs=copy.deepcopy(specified_feature_configs),\n        lattices=[['numerical_1', 'categorical'],\n                  ['numerical_2', 'categorical']],\n        num_lattices=2,\n        lattice_rank=2,\n        separate_calibrators=True,\n        output_initialization=[-1.0, 1.0])\n\n    with self.assertRaisesRegex(\n        ValueError, 'model_config.lattices must be set to 
\\'random\\'.'):\n      premade_lib.set_random_lattice_ensemble(specified_model_config)\n\n  def testSetCategoricalMonotonicities(self):\n    set_feature_configs = copy.deepcopy(unspecified_feature_configs)\n    premade_lib.set_categorical_monotonicities(set_feature_configs)\n    expectation = [(0, 1)]\n    self.assertListEqual(expectation, set_feature_configs[2].monotonicity)\n\n  def testVerifyConfig(self):\n    unspecified_model_config = configs.CalibratedLatticeEnsembleConfig(\n        feature_configs=copy.deepcopy(unspecified_feature_configs),\n        lattices='random',\n        num_lattices=3,\n        lattice_rank=2,\n        separate_calibrators=True,\n        output_initialization=[-1.0, 1.0])\n\n    with self.assertRaisesRegex(\n        ValueError, 'Lattices are not fully specified for ensemble config.'):\n      premade_lib.verify_config(unspecified_model_config)\n    premade_lib.set_random_lattice_ensemble(unspecified_model_config)\n    with self.assertRaisesRegex(\n        ValueError,\n        'Element 0 for list/tuple 0 for feature categorical monotonicity is '\n        'not an index: 0.0'):\n      premade_lib.verify_config(unspecified_model_config)\n    fixed_feature_configs = copy.deepcopy(unspecified_feature_configs)\n    premade_lib.set_categorical_monotonicities(fixed_feature_configs)\n    unspecified_model_config.feature_configs = fixed_feature_configs\n    premade_lib.verify_config(unspecified_model_config)\n\n    specified_model_config = configs.CalibratedLatticeEnsembleConfig(\n        feature_configs=copy.deepcopy(specified_feature_configs),\n        lattices=[['numerical_1', 'categorical'],\n                  ['numerical_2', 'categorical']],\n        num_lattices=2,\n        lattice_rank=2,\n        separate_calibrators=True,\n        output_initialization=[-1.0, 1.0])\n\n    premade_lib.verify_config(specified_model_config)\n\n  def testLatticeEnsembleFromConfig(self):\n    model_config = configs.CalibratedLatticeEnsembleConfig(\n        feature_configs=copy.deepcopy(feature_configs),\n        lattices=[['numerical_1', 'categorical'],\n                  ['numerical_2', 'categorical']],\n        num_lattices=2,\n        lattice_rank=2,\n        separate_calibrators=True,\n        regularizer_configs=[\n            configs.RegularizerConfig('calib_hessian', l2=1e-3),\n            configs.RegularizerConfig('torsion', l2=1e-4),\n        ],\n        output_min=-1.0,\n        output_max=1.0,\n        output_calibration=True,\n        output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.])\n    model = premade.CalibratedLatticeEnsemble(model_config)\n    loaded_model = premade.CalibratedLatticeEnsemble.from_config(\n        model.get_config(), custom_objects=premade.get_custom_objects())\n    self.assertEqual(\n        json.dumps(model.get_config(), sort_keys=True, cls=self.Encoder),\n        json.dumps(loaded_model.get_config(), sort_keys=True, cls=self.Encoder))\n\n  def testLatticeFromConfig(self):\n    model_config = configs.CalibratedLatticeConfig(\n        feature_configs=copy.deepcopy(feature_configs),\n        regularizer_configs=[\n            configs.RegularizerConfig('calib_wrinkle', l2=1e-3),\n            configs.RegularizerConfig('torsion', l2=1e-3),\n        ],\n        output_min=0.0,\n        output_max=1.0,\n        output_calibration=True,\n        output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.])\n    model = premade.CalibratedLattice(model_config)\n    loaded_model = 
premade.CalibratedLattice.from_config(\n        model.get_config(), custom_objects=premade.get_custom_objects())\n    self.assertEqual(\n        json.dumps(model.get_config(), sort_keys=True, cls=self.Encoder),\n        json.dumps(loaded_model.get_config(), sort_keys=True, cls=self.Encoder))\n\n  def testLatticeSimplexFromConfig(self):\n    model_config = configs.CalibratedLatticeConfig(\n        feature_configs=copy.deepcopy(feature_configs),\n        regularizer_configs=[\n            configs.RegularizerConfig('calib_wrinkle', l2=1e-3),\n            configs.RegularizerConfig('torsion', l2=1e-3),\n        ],\n        output_min=0.0,\n        output_max=1.0,\n        interpolation='simplex',\n        output_calibration=True,\n        output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.])\n    model = premade.CalibratedLattice(model_config)\n    loaded_model = premade.CalibratedLattice.from_config(\n        model.get_config(), custom_objects=premade.get_custom_objects())\n    self.assertEqual(\n        json.dumps(model.get_config(), sort_keys=True, cls=self.Encoder),\n        json.dumps(loaded_model.get_config(), sort_keys=True, cls=self.Encoder))\n\n  def testLinearFromConfig(self):\n    model_config = configs.CalibratedLinearConfig(\n        feature_configs=copy.deepcopy(feature_configs),\n        regularizer_configs=[\n            configs.RegularizerConfig('calib_hessian', l2=1e-4),\n            configs.RegularizerConfig('torsion', l2=1e-3),\n        ],\n        use_bias=True,\n        output_min=0.0,\n        output_max=1.0,\n        output_calibration=True,\n        output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.])\n    model = premade.CalibratedLinear(model_config)\n    loaded_model = premade.CalibratedLinear.from_config(\n        model.get_config(), custom_objects=premade.get_custom_objects())\n    self.assertEqual(\n        json.dumps(model.get_config(), sort_keys=True, cls=self.Encoder),\n        json.dumps(loaded_model.get_config(), sort_keys=True, cls=self.Encoder))\n\n  def testAggregateFromConfig(self):\n    model_config = configs.AggregateFunctionConfig(\n        feature_configs=feature_configs,\n        regularizer_configs=[\n            configs.RegularizerConfig('calib_hessian', l2=1e-4),\n            configs.RegularizerConfig('torsion', l2=1e-3),\n        ],\n        middle_calibration=True,\n        middle_monotonicity='increasing',\n        output_min=0.0,\n        output_max=1.0,\n        output_calibration=True,\n        output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.])\n    model = premade.AggregateFunction(model_config)\n    loaded_model = premade.AggregateFunction.from_config(\n        model.get_config(), custom_objects=premade.get_custom_objects())\n    self.assertEqual(\n        json.dumps(model.get_config(), sort_keys=True, cls=self.Encoder),\n        json.dumps(loaded_model.get_config(), sort_keys=True, cls=self.Encoder))\n\n  @parameterized.parameters(\n      ('hypercube', 'all_vertices', 0, 0.85),\n      ('simplex', 'all_vertices', 0, 0.89),\n      ('hypercube', 'kronecker_factored', 2, 0.82),\n      ('hypercube', 'kronecker_factored', 4, 0.82),\n  )\n  def testCalibratedLatticeEnsembleCrystals(self, interpolation,\n                                            parameterization, num_terms,\n                                            expected_minimum_auc):\n    # Construct model.\n    self._ResetAllBackends()\n    crystals_feature_configs = 
copy.deepcopy(self.heart_feature_configs)\n    model_config = configs.CalibratedLatticeEnsembleConfig(\n        regularizer_configs=[\n            configs.RegularizerConfig(name='torsion', l2=1e-4),\n            configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\n        ],\n        feature_configs=crystals_feature_configs,\n        lattices='crystals',\n        num_lattices=6,\n        lattice_rank=5,\n        interpolation=interpolation,\n        parameterization=parameterization,\n        num_terms=num_terms,\n        separate_calibrators=True,\n        output_calibration=False,\n        output_initialization=[-2, 2],\n    )\n    if parameterization == 'kronecker_factored':\n      model_config.regularizer_configs = None\n      for feature_config in model_config.feature_configs:\n        feature_config.lattice_size = 2\n        feature_config.unimodality = 'none'\n        feature_config.reflects_trust_in = None\n        feature_config.dominates = None\n        feature_config.regularizer_configs = None\n    # Perform prefitting steps.\n    prefitting_model_config = premade_lib.construct_prefitting_model_config(\n        model_config)\n    prefitting_model = premade.CalibratedLatticeEnsemble(\n        prefitting_model_config)\n    prefitting_model.compile(\n        loss=keras.losses.BinaryCrossentropy(from_logits=True),\n        optimizer=keras.optimizers.legacy.Adam(0.01),\n    )\n    prefitting_model.fit(\n        self.heart_train_x,\n        self.heart_train_y,\n        batch_size=100,\n        epochs=50,\n        verbose=False)\n    premade_lib.set_crystals_lattice_ensemble(model_config,\n                                              prefitting_model_config,\n                                              prefitting_model)\n    # Construct and train final model\n    model = premade.CalibratedLatticeEnsemble(model_config)\n    model.compile(\n        loss=keras.losses.BinaryCrossentropy(from_logits=True),\n        metrics=keras.metrics.AUC(from_logits=True),\n        optimizer=keras.optimizers.legacy.Adam(0.01),\n    )\n    model.fit(\n        self.heart_train_x,\n        self.heart_train_y,\n        batch_size=100,\n        epochs=200,\n        verbose=False)\n    results = model.evaluate(\n        self.heart_test_x, self.heart_test_y, verbose=False)\n    logging.info('Calibrated lattice ensemble crystals classifier results:')\n    logging.info(results)\n    self.assertGreater(results[1], expected_minimum_auc)\n\n  @parameterized.parameters(\n      ('hypercube', 'all_vertices', 0, 0.85),\n      ('simplex', 'all_vertices', 0, 0.88),\n      ('hypercube', 'kronecker_factored', 2, 0.82),\n      ('hypercube', 'kronecker_factored', 4, 0.82),\n  )\n  def testCalibratedLatticeEnsembleRTL(self, interpolation, parameterization,\n                                       num_terms, expected_minimum_auc):\n    # Construct model.\n    self._ResetAllBackends()\n    rtl_feature_configs = copy.deepcopy(self.heart_feature_configs)\n    for feature_config in rtl_feature_configs:\n      feature_config.lattice_size = 2\n      feature_config.unimodality = 'none'\n      feature_config.reflects_trust_in = None\n      feature_config.dominates = None\n      feature_config.regularizer_configs = None\n    model_config = configs.CalibratedLatticeEnsembleConfig(\n        regularizer_configs=[\n            configs.RegularizerConfig(name='torsion', l2=1e-4),\n            configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\n        ],\n        feature_configs=rtl_feature_configs,\n        
lattices='rtl_layer',\n        num_lattices=6,\n        lattice_rank=5,\n        interpolation=interpolation,\n        parameterization=parameterization,\n        num_terms=num_terms,\n        separate_calibrators=True,\n        output_calibration=False,\n        output_initialization=[-2, 2],\n    )\n    # We must remove all regularization if using 'kronecker_factored'.\n    if parameterization == 'kronecker_factored':\n      model_config.regularizer_configs = None\n    # Construct and train final model\n    model = premade.CalibratedLatticeEnsemble(model_config)\n    model.compile(\n        loss=keras.losses.BinaryCrossentropy(from_logits=True),\n        metrics=keras.metrics.AUC(from_logits=True),\n        optimizer=keras.optimizers.legacy.Adam(0.01),\n    )\n    model.fit(\n        self.heart_train_x,\n        self.heart_train_y,\n        batch_size=100,\n        epochs=200,\n        verbose=False)\n    results = model.evaluate(\n        self.heart_test_x, self.heart_test_y, verbose=False)\n    logging.info('Calibrated lattice ensemble rtl classifier results:')\n    logging.info(results)\n    self.assertGreater(results[1], expected_minimum_auc)\n\n  @parameterized.parameters(\n      ('hypercube', 'all_vertices', 0, 0.81),\n      ('simplex', 'all_vertices', 0, 0.81),\n      ('hypercube', 'kronecker_factored', 2, 0.77),\n      ('hypercube', 'kronecker_factored', 4, 0.77),\n  )\n  def testCalibratedLattice(self, interpolation, parameterization, num_terms,\n                            expected_minimum_auc):\n    # Construct model configuration.\n    self._ResetAllBackends()\n    lattice_feature_configs = copy.deepcopy(self.heart_feature_configs[:5])\n    model_config = configs.CalibratedLatticeConfig(\n        feature_configs=lattice_feature_configs,\n        interpolation=interpolation,\n        parameterization=parameterization,\n        num_terms=num_terms,\n        regularizer_configs=[\n            configs.RegularizerConfig(name='torsion', l2=1e-4),\n            configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\n        ],\n        output_calibration=False,\n        output_initialization=[-2, 2],\n    )\n    if parameterization == 'kronecker_factored':\n      model_config.regularizer_configs = None\n      for feature_config in model_config.feature_configs:\n        feature_config.lattice_size = 2\n        feature_config.unimodality = 'none'\n        feature_config.reflects_trust_in = None\n        feature_config.dominates = None\n        feature_config.regularizer_configs = None\n    # Construct and train final model\n    model = premade.CalibratedLattice(model_config)\n    model.compile(\n        loss=keras.losses.BinaryCrossentropy(from_logits=True),\n        metrics=keras.metrics.AUC(from_logits=True),\n        optimizer=keras.optimizers.legacy.Adam(0.01),\n    )\n    model.fit(\n        self.heart_train_x[:5],\n        self.heart_train_y,\n        batch_size=100,\n        epochs=200,\n        verbose=False)\n    results = model.evaluate(\n        self.heart_test_x[:5], self.heart_test_y, verbose=False)\n    logging.info('Calibrated lattice classifier results:')\n    logging.info(results)\n    self.assertGreater(results[1], expected_minimum_auc)\n\n  def testLearnedCalibrationInputKeypoints(self):\n    # First let's try a CalibratedLatticeEnsemble\n    self._ResetAllBackends()\n    learned_keypoints_feature_configs = copy.deepcopy(\n        self.heart_feature_configs)\n    for feature_config in learned_keypoints_feature_configs:\n      
feature_config.pwl_calibration_input_keypoints_type = 'learned_interior'\n    model_config = configs.CalibratedLatticeEnsembleConfig(\n        regularizer_configs=[\n            configs.RegularizerConfig(name='torsion', l2=1e-4),\n            configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\n        ],\n        feature_configs=learned_keypoints_feature_configs,\n        lattices='random',\n        num_lattices=6,\n        lattice_rank=5,\n        interpolation='hypercube',\n        separate_calibrators=True,\n        output_calibration=True,\n        output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.],\n        output_calibration_input_keypoints_type='learned_interior',\n    )\n    premade_lib.set_random_lattice_ensemble(model_config)\n    # Construct and train final model\n    model = premade.CalibratedLatticeEnsemble(model_config)\n    model.compile(\n        loss=keras.losses.BinaryCrossentropy(from_logits=True),\n        metrics=keras.metrics.AUC(from_logits=True),\n        optimizer=keras.optimizers.legacy.Adam(0.01),\n    )\n    model.fit(\n        self.heart_train_x,\n        self.heart_train_y,\n        batch_size=100,\n        epochs=200,\n        verbose=False)\n    results = model.evaluate(\n        self.heart_test_x, self.heart_test_y, verbose=False)\n    logging.info('Calibrated random lattice ensemble classifier results:')\n    logging.info(results)\n    self.assertGreater(results[1], 0.82)\n\n    # Now let's try a CalibratedLattice\n    self._ResetAllBackends()\n    model_config = configs.CalibratedLatticeConfig(\n        feature_configs=learned_keypoints_feature_configs[:5],\n        interpolation='hypercube',\n        regularizer_configs=[\n            configs.RegularizerConfig(name='torsion', l2=1e-4),\n            configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\n        ],\n        output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.],\n        output_calibration_input_keypoints_type='learned_interior',\n    )\n    # Construct and train final model\n    model = premade.CalibratedLattice(model_config)\n    model.compile(\n        loss=keras.losses.BinaryCrossentropy(from_logits=True),\n        metrics=keras.metrics.AUC(from_logits=True),\n        optimizer=keras.optimizers.legacy.Adam(0.01),\n    )\n    model.fit(\n        self.heart_train_x[:5],\n        self.heart_train_y,\n        batch_size=100,\n        epochs=200,\n        verbose=False)\n    results = model.evaluate(\n        self.heart_test_x[:5], self.heart_test_y, verbose=False)\n    logging.info('Calibrated lattice classifier results:')\n    logging.info(results)\n    self.assertGreater(results[1], 0.79)\n\n  @parameterized.parameters(\n      ('all_vertices', 0),\n      ('kronecker_factored', 2),\n  )\n  def testLatticeEnsembleH5FormatSaveLoad(self, parameterization, num_terms):\n    model_config = configs.CalibratedLatticeEnsembleConfig(\n        feature_configs=copy.deepcopy(feature_configs),\n        lattices=[['numerical_1', 'categorical'],\n                  ['numerical_2', 'categorical']],\n        num_lattices=2,\n        lattice_rank=2,\n        parameterization=parameterization,\n        num_terms=num_terms,\n        separate_calibrators=True,\n        regularizer_configs=[\n            configs.RegularizerConfig('calib_hessian', l2=1e-3),\n            configs.RegularizerConfig('torsion', l2=1e-4),\n        ],\n        output_min=-1.0,\n        output_max=1.0,\n        output_calibration=True,\n    
    output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.])\n    if parameterization == 'kronecker_factored':\n      model_config.regularizer_configs = None\n      for feature_config in model_config.feature_configs:\n        feature_config.lattice_size = 2\n        feature_config.unimodality = 'none'\n        feature_config.reflects_trust_in = None\n        feature_config.dominates = None\n        feature_config.regularizer_configs = None\n    model = premade.CalibratedLatticeEnsemble(model_config)\n    # Compile and fit model.\n    model.compile(loss='mse', optimizer=keras.optimizers.legacy.Adam(0.1))\n    model.fit(fake_data['train_xs'], fake_data['train_ys'])\n    # Save model using H5 format.\n    with tempfile.NamedTemporaryFile(suffix='.h5') as f:\n      keras.models.save_model(model, f.name)\n      loaded_model = keras.models.load_model(\n          f.name, custom_objects=premade.get_custom_objects()\n      )\n      self.assertAllClose(\n          model.predict(fake_data['eval_xs']),\n          loaded_model.predict(fake_data['eval_xs']))\n\n  @parameterized.parameters(\n      ('all_vertices', 0),\n      ('kronecker_factored', 2),\n  )\n  def testLatticeEnsembleRTLH5FormatSaveLoad(self, parameterization, num_terms):\n    rtl_feature_configs = copy.deepcopy(feature_configs)\n    for feature_config in rtl_feature_configs:\n      feature_config.lattice_size = 2\n      feature_config.unimodality = 'none'\n      feature_config.reflects_trust_in = None\n      feature_config.dominates = None\n      feature_config.regularizer_configs = None\n    model_config = configs.CalibratedLatticeEnsembleConfig(\n        feature_configs=copy.deepcopy(rtl_feature_configs),\n        lattices='rtl_layer',\n        num_lattices=2,\n        lattice_rank=2,\n        parameterization=parameterization,\n        num_terms=num_terms,\n        separate_calibrators=True,\n        regularizer_configs=[\n            configs.RegularizerConfig('calib_hessian', l2=1e-3),\n            configs.RegularizerConfig('torsion', l2=1e-4),\n        ],\n        output_min=-1.0,\n        output_max=1.0,\n        output_calibration=True,\n        output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.])\n    if parameterization == 'kronecker_factored':\n      model_config.regularizer_configs = None\n    model = premade.CalibratedLatticeEnsemble(model_config)\n    # Compile and fit model.\n    model.compile(loss='mse', optimizer=keras.optimizers.legacy.Adam(0.1))\n    model.fit(fake_data['train_xs'], fake_data['train_ys'])\n    # Save model using H5 format.\n    with tempfile.NamedTemporaryFile(suffix='.h5') as f:\n      keras.models.save_model(model, f.name)\n      loaded_model = keras.models.load_model(\n          f.name, custom_objects=premade.get_custom_objects()\n      )\n      self.assertAllClose(\n          model.predict(fake_data['eval_xs']),\n          loaded_model.predict(fake_data['eval_xs']))\n\n  @parameterized.parameters(\n      ('all_vertices', 0),\n      ('kronecker_factored', 2),\n  )\n  def testLatticeH5FormatSaveLoad(self, parameterization, num_terms):\n    model_config = configs.CalibratedLatticeConfig(\n        feature_configs=copy.deepcopy(feature_configs),\n        parameterization=parameterization,\n        num_terms=num_terms,\n        regularizer_configs=[\n            configs.RegularizerConfig('calib_wrinkle', l2=1e-3),\n            configs.RegularizerConfig('torsion', l2=1e-3),\n        ],\n        output_min=0.0,\n        output_max=1.0,\n 
       output_calibration=True,\n        output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.])\n    if parameterization == 'kronecker_factored':\n      model_config.regularizer_configs = None\n      for feature_config in model_config.feature_configs:\n        feature_config.lattice_size = 2\n        feature_config.unimodality = 'none'\n        feature_config.reflects_trust_in = None\n        feature_config.dominates = None\n        feature_config.regularizer_configs = None\n    model = premade.CalibratedLattice(model_config)\n    # Compile and fit model.\n    model.compile(loss='mse', optimizer=keras.optimizers.legacy.Adam(0.1))\n    model.fit(fake_data['train_xs'], fake_data['train_ys'])\n    # Save model using H5 format.\n    with tempfile.NamedTemporaryFile(suffix='.h5') as f:\n      keras.models.save_model(model, f.name)\n      loaded_model = keras.models.load_model(\n          f.name, custom_objects=premade.get_custom_objects()\n      )\n      self.assertAllClose(\n          model.predict(fake_data['eval_xs']),\n          loaded_model.predict(fake_data['eval_xs']))\n\n  def testLinearH5FormatSaveLoad(self):\n    model_config = configs.CalibratedLinearConfig(\n        feature_configs=copy.deepcopy(feature_configs),\n        regularizer_configs=[\n            configs.RegularizerConfig('calib_hessian', l2=1e-4),\n            configs.RegularizerConfig('torsion', l2=1e-3),\n        ],\n        use_bias=True,\n        output_min=0.0,\n        output_max=1.0,\n        output_calibration=True,\n        output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.])\n    model = premade.CalibratedLinear(model_config)\n    # Compile and fit model.\n    model.compile(loss='mse', optimizer=keras.optimizers.legacy.Adam(0.1))\n    model.fit(fake_data['train_xs'], fake_data['train_ys'])\n    # Save model using H5 format.\n    with tempfile.NamedTemporaryFile(suffix='.h5') as f:\n      keras.models.save_model(model, f.name)\n      loaded_model = keras.models.load_model(\n          f.name, custom_objects=premade.get_custom_objects()\n      )\n      self.assertAllClose(\n          model.predict(fake_data['eval_xs']),\n          loaded_model.predict(fake_data['eval_xs']))\n\n  def testAggregateH5FormatSaveLoad(self):\n    model_config = configs.AggregateFunctionConfig(\n        feature_configs=feature_configs,\n        regularizer_configs=[\n            configs.RegularizerConfig('calib_hessian', l2=1e-4),\n            configs.RegularizerConfig('torsion', l2=1e-3),\n        ],\n        middle_calibration=True,\n        middle_monotonicity='increasing',\n        output_min=0.0,\n        output_max=1.0,\n        output_calibration=True,\n        output_calibration_num_keypoints=5,\n        output_initialization=[-2., -1., 0., 1., 2.])\n    model = premade.AggregateFunction(model_config)\n    # Compile and fit model.\n    model.compile(loss='mse', optimizer=keras.optimizers.legacy.Adam(0.1))\n    model.fit(fake_data['train_xs'], fake_data['train_ys'])\n    # Save model using H5 format.\n    with tempfile.NamedTemporaryFile(suffix='.h5') as f:\n      # Note: because of naming clashes in the optimizer, we cannot include it\n      # when saving in HDF5. 
The keras team has informed us that we should not\n      # push to support this since SavedModel format is the new default and no\n      # new HDF5 functionality is desired.\n      keras.models.save_model(model, f.name, include_optimizer=False)\n      loaded_model = keras.models.load_model(\n          f.name, custom_objects=premade.get_custom_objects()\n      )\n      self.assertAllClose(\n          model.predict(fake_data['eval_xs']),\n          loaded_model.predict(fake_data['eval_xs']))\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_lattice/python/pwl_calibration_layer.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Piecewise linear calibration layer.\n\nKeras implementation of tensorflow lattice pwl calibration layer. Layer takes\nsingle or multi-dimensional input and transforms it using piecewise linear\nfunctions following monotonicity, convexity/concavity and bounds constraints if\nspecified.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl import logging\nimport numpy as np\nimport six\nimport tensorflow as tf\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\nfrom . import pwl_calibration_lib\nfrom . import utils\n\nINTERPOLATION_KEYPOINTS_NAME = \"interpolation_keypoints\"\nLENGTHS_NAME = \"lengths\"\nMISSING_INPUT_VALUE_NAME = \"missing_input_value\"\nPWL_CALIBRATION_KERNEL_NAME = \"pwl_calibration_kernel\"\nPWL_CALIBRATION_MISSING_OUTPUT_NAME = \"pwl_calibration_missing_output\"\nINTERPOLATION_LOGITS_NAME = \"interpolation_logits\"\n\n\nclass PWLCalibration(keras.layers.Layer):\n  # pyformat: disable\n  \"\"\"Piecewise linear calibration layer.\n\n  Layer takes input of shape `(batch_size, units)` or `(batch_size, 1)` and\n  transforms it using `units` number of piecewise linear functions following\n  monotonicity, convexity and bounds constraints if specified. If multi\n  dimensional input is provides, each output will be for the corresponding\n  input, otherwise all PWL functions will act on the same input. All units share\n  the same layer configuration, but each has their separate set of trained\n  parameters.\n\n  See `tfl.layers.ParallelCombination` layer for using PWLCalibration layer\n  within Sequential Keras models.\n\n  Input shape:\n  Single input should be a rank-2 tensor with shape: `(batch_size, units)` or\n  `(batch_size, 1)`. The input can also be a list of two tensors of the same\n  shape where the first tensor is the regular input tensor and the second is the\n  `is_missing` tensor. In the `is_missing` tensor, 1.0 represents missing input\n  and 0.0 represents available input.\n\n  Output shape:\n  If units > 1 and split_outputs is True, a length `units` list of Rank-2\n    tensors with shape `(batch_size, 1)`. Otherwise, a Rank-2 tensor with shape:\n    `(batch_size, units)`\n\n  Attributes:\n    - All `__init__` arguments.\n    kernel: TF variable which stores weights of piecewise linear function.\n    missing_output: TF variable which stores output learned for missing input.\n      Or TF Constant which stores `missing_output_value` if one is provided.\n      Available only if `impute_missing` is True.\n\n  Example:\n\n  ```python\n  calibrator = tfl.layers.PWLCalibration(\n      # Key-points of piecewise-linear function.\n      input_keypoints=np.linspace(1., 4., num=4),\n      # Output can be bounded, e.g. 
when this layer feeds into a lattice.\n      output_min=0.0,\n      output_max=2.0,\n      # You can specify monotonicity and other shape constraints for the layer.\n      monotonicity='increasing',\n      # You can specify TFL regularizers as tuple ('regularizer name', l1, l2).\n      # You can also pass any keras Regularizer object.\n      kernel_regularizer=('hessian', 0.0, 1e-4),\n  )\n  ```\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self,\n               input_keypoints,\n               units=1,\n               output_min=None,\n               output_max=None,\n               clamp_min=False,\n               clamp_max=False,\n               monotonicity=\"none\",\n               convexity=\"none\",\n               is_cyclic=False,\n               kernel_initializer=\"equal_heights\",\n               kernel_regularizer=None,\n               impute_missing=False,\n               missing_input_value=None,\n               missing_output_value=None,\n               num_projection_iterations=8,\n               split_outputs=False,\n               input_keypoints_type=\"fixed\",\n               **kwargs):\n    # pyformat: disable\n    \"\"\"Initializes an instance of `PWLCalibration`.\n\n    Args:\n      input_keypoints: Ordered list of keypoints of piecewise linear function.\n        Can be anything accepted by tf.convert_to_tensor().\n      units: Output dimension of the layer. See class comments for details.\n      output_min: Minimum output of calibrator.\n      output_max: Maximum output of calibrator.\n      clamp_min: For monotonic calibrators ensures that output_min is reached.\n      clamp_max: For monotonic calibrators ensures that output_max is reached.\n      monotonicity: Constraints piecewise linear function to be monotonic using\n        'increasing' or 1 to indicate increasing monotonicity, 'decreasing' or\n        -1 to indicate decreasing monotonicity and 'none' or 0 to indicate no\n        monotonicity constraints.\n      convexity: Constraints piecewise linear function to be convex or concave.\n        Convexity is indicated by 'convex' or 1, concavity is indicated by\n        'concave' or -1, 'none' or 0 indicates no convexity/concavity\n        constraints.\n        Concavity together with increasing monotonicity as well as convexity\n        together with decreasing monotonicity results in diminishing return\n        constraints.\n        Consider increasing the value of `num_projection_iterations` if\n        convexity is specified, especially with larger number of keypoints.\n      is_cyclic: Whether the output for last keypoint should be identical to\n        output for first keypoint. This is useful for features such as\n        \"time of day\" or \"degree of turn\". If inputs are discrete and exactly\n        match keypoints then is_cyclic will have an effect only if TFL\n        regularizers are being used.\n      kernel_initializer: None or one of:\n        - String `\"equal_heights\"`: For pieces of pwl function to have equal\n          heights.\n        - String `\"equal_slopes\"`: For pieces of pwl function to have equal\n          slopes.\n        - Any Keras initializer object. If you are passing such object make sure\n          that you know how layer stores its data.\n      kernel_regularizer: None or single element or list of following:\n        - Tuple `(\"laplacian\", l1, l2)` where `l1` and `l2` are floats which\n          represent corresponding regularization amount for Laplacian\n          regularizer. 
It penalizes the first derivative to make the function\n          more constant. See `tfl.pwl_calibration.LaplacianRegularizer` for more\n          details.\n        - Tuple `(\"hessian\", l1, l2)` where `l1` and `l2` are floats which\n          represent corresponding regularization amount for Hessian regularizer.\n          It penalizes the second derivative to make the function more linear.\n          See `tfl.pwl_calibration.HessianRegularizer` for more details.\n        - Tuple `(\"wrinkle\", l1, l2)` where `l1` and `l2` are floats which\n          represent corresponding regularization amount for wrinkle regularizer.\n          It penalizes the third derivative to make the function more smooth.\n          See 'tfl.pwl_calibration.WrinkleRegularizer` for more details.\n        - Any Keras regularizer object.\n      impute_missing: Whether to learn an output for cases where input data is\n        missing. If set to True, either `missing_input_value` should be\n        initialized, or the `call()` method should get pair of tensors. See\n        class input shape description for more details.\n      missing_input_value: If set, all inputs which are equal to this value will\n        be considered as missing. Can not be set if `impute_missing` is False.\n      missing_output_value: If set, instead of learning output for missing\n        inputs, simply maps them into this value. Can not be set if\n        `impute_missing` is False.\n      num_projection_iterations: Number of iterations of the Dykstra's\n        projection algorithm. Constraints are strictly satisfied at the end of\n        each update, but the update will be closer to a true L2 projection with\n        higher number of iterations. See\n        `tfl.pwl_calibration_lib.project_all_constraints` for more details.\n      split_outputs: Whether to split the output tensor into a list of\n        outputs for each unit. Ignored if units < 2.\n      input_keypoints_type: One of \"fixed\" or \"learned_interior\". If\n        \"learned_interior\", keypoints are initialized to the values in\n        `input_keypoints` but then allowed to vary during training, with the\n        exception of the first and last keypoint location which are fixed.\n        Convexity can only be imposed with \"fixed\".\n      **kwargs: Other args passed to `keras.layers.Layer` initializer.\n\n    Raises:\n      ValueError: If layer hyperparameters are invalid.\n    \"\"\"\n    # pyformat: enable\n    super(PWLCalibration, self).__init__(**kwargs)\n\n    pwl_calibration_lib.verify_hyperparameters(\n        input_keypoints=input_keypoints,\n        output_min=output_min,\n        output_max=output_max,\n        monotonicity=monotonicity,\n        convexity=convexity,\n        is_cyclic=is_cyclic,\n        input_keypoints_type=input_keypoints_type)\n    if missing_input_value is not None and not impute_missing:\n      raise ValueError(\"'missing_input_value' is specified, but \"\n                       \"'impute_missing' is set to False. \"\n                       \"'missing_input_value': \" + str(missing_input_value))\n    if missing_output_value is not None and not impute_missing:\n      raise ValueError(\"'missing_output_value' is specified, but \"\n                       \"'impute_missing' is set to False. 
\"\n                       \"'missing_output_value': \" + str(missing_output_value))\n    if input_keypoints is None:\n      raise ValueError(\"'input_keypoints' can't be None\")\n    if monotonicity is None:\n      raise ValueError(\"'monotonicity' can't be None. Did you mean '0'?\")\n    if convexity not in (\"none\",\n                         0) and input_keypoints_type == \"learned_interior\":\n      raise ValueError(\"Cannot set input_keypoints_type to 'learned_interior'\"\n                       \" and impose convexity constraints.\")\n\n    self.input_keypoints = input_keypoints\n    self.units = units\n    self.output_min = output_min\n    self.output_max = output_max\n    self.clamp_min = clamp_min\n    self.clamp_max = clamp_max\n    (self._output_init_min, self._output_init_max, self._output_min_constraints,\n     self._output_max_constraints\n    ) = pwl_calibration_lib.convert_all_constraints(self.output_min,\n                                                    self.output_max,\n                                                    self.clamp_min,\n                                                    self.clamp_max)\n\n    self.monotonicity = monotonicity\n    self.convexity = convexity\n    self.is_cyclic = is_cyclic\n\n    if kernel_initializer == \"equal_heights\":\n      self.kernel_initializer = UniformOutputInitializer(\n          output_min=self._output_init_min,\n          output_max=self._output_init_max,\n          monotonicity=self.monotonicity)\n    elif kernel_initializer == \"equal_slopes\":\n      self.kernel_initializer = UniformOutputInitializer(\n          output_min=self._output_init_min,\n          output_max=self._output_init_max,\n          monotonicity=self.monotonicity,\n          keypoints=self.input_keypoints)\n    else:\n      # Keras deserialization logic must have explicit acceess to all custom\n      # classes. 
This is standard way to provide such access.\n      with keras.utils.custom_object_scope({\n          \"UniformOutputInitializer\": UniformOutputInitializer,\n      }):\n        self.kernel_initializer = keras.initializers.get(kernel_initializer)\n\n    self.kernel_regularizer = []\n    if kernel_regularizer:\n      if (callable(kernel_regularizer) or\n          (isinstance(kernel_regularizer, tuple) and\n           isinstance(kernel_regularizer[0], six.string_types))):\n        kernel_regularizer = [kernel_regularizer]\n\n      for reg in kernel_regularizer:\n        if isinstance(reg, tuple):\n          (name, l1, l2) = reg\n          if name.lower() == \"laplacian\":\n            self.kernel_regularizer.append(\n                LaplacianRegularizer(l1=l1, l2=l2, is_cyclic=self.is_cyclic))\n          elif name.lower() == \"hessian\":\n            self.kernel_regularizer.append(\n                HessianRegularizer(l1=l1, l2=l2, is_cyclic=self.is_cyclic))\n          elif name.lower() == \"wrinkle\":\n            self.kernel_regularizer.append(\n                WrinkleRegularizer(l1=l1, l2=l2, is_cyclic=self.is_cyclic))\n          else:\n            raise ValueError(\"Unknown custom lattice regularizer: %s\" % reg)\n        else:\n          # This is needed for Keras deserialization logic to be aware of our\n          # custom objects.\n          with keras.utils.custom_object_scope({\n              \"LaplacianRegularizer\": LaplacianRegularizer,\n              \"HessianRegularizer\": HessianRegularizer,\n              \"WrinkleRegularizer\": WrinkleRegularizer,\n          }):\n            self.kernel_regularizer.append(keras.regularizers.get(reg))\n\n    self.impute_missing = impute_missing\n    self.missing_input_value = missing_input_value\n    self.missing_output_value = missing_output_value\n    self.num_projection_iterations = num_projection_iterations\n    self.split_outputs = split_outputs\n    self.input_keypoints_type = input_keypoints_type\n\n  def build(self, input_shape):\n    \"\"\"Standard Keras build() method.\"\"\"\n    input_keypoints = np.array(self.input_keypoints)\n    # Don't need last keypoint for interpolation because we need only beginnings\n    # of intervals.\n    if self.input_keypoints_type == \"fixed\":\n      self._interpolation_keypoints = tf.constant(\n          input_keypoints[:-1],\n          dtype=self.dtype,\n          name=INTERPOLATION_KEYPOINTS_NAME)\n      self._lengths = tf.constant(\n          input_keypoints[1:] - input_keypoints[:-1],\n          dtype=self.dtype,\n          name=LENGTHS_NAME)\n    else:\n      self._keypoint_min = input_keypoints[0]\n      self._keypoint_range = input_keypoints[-1] - input_keypoints[0]\n      # Logits are initialized such that they will recover the scaled keypoint\n      # gaps in input_keypoints.\n      initial_logits = np.log(\n          (input_keypoints[1:] - input_keypoints[:-1]) / self._keypoint_range)\n      tiled_logits = np.tile(initial_logits, self.units)\n      self.interpolation_logits = self.add_weight(\n          INTERPOLATION_LOGITS_NAME,\n          shape=[self.units, len(input_keypoints) - 1],\n          initializer=tf.constant_initializer(tiled_logits),\n          dtype=self.dtype)\n\n    constraints = PWLCalibrationConstraints(\n        monotonicity=self.monotonicity,\n        convexity=self.convexity,\n        lengths=self._lengths if self.input_keypoints_type == \"fixed\" else None,\n        output_min=self.output_min,\n        output_max=self.output_max,\n        
output_min_constraints=self._output_min_constraints,\n        output_max_constraints=self._output_max_constraints,\n        num_projection_iterations=self.num_projection_iterations)\n\n    if not self.kernel_regularizer:\n      kernel_reg = None\n    elif len(self.kernel_regularizer) == 1:\n      kernel_reg = self.kernel_regularizer[0]\n    else:\n      # Keras interface assumes only one regularizer, so sum all regularization\n      # losses which we have.\n      kernel_reg = lambda x: tf.add_n([r(x) for r in self.kernel_regularizer])\n\n    # If 'is_cyclic' is specified, the last weight will be computed from previous\n    # weights in order to connect the last keypoint with the first.\n    num_weights = input_keypoints.size - self.is_cyclic\n\n    # PWL calibration layer kernel is a units-column matrix. The first row of the\n    # matrix represents bias. All remaining rows represent deltas in y-value\n    # compared to the previous point, i.e. heights of segments.\n    self.kernel = self.add_weight(\n        PWL_CALIBRATION_KERNEL_NAME,\n        shape=[num_weights, self.units],\n        initializer=self.kernel_initializer,\n        regularizer=kernel_reg,\n        constraint=constraints,\n        dtype=self.dtype)\n\n    if self.kernel_regularizer and not tf.executing_eagerly():\n      # Keras has its own mechanism to handle regularization losses which\n      # does not use GraphKeys, but we want to also add losses to graph keys so\n      # they are easily accessible when the layer is being used outside of Keras.\n      # Adding losses to GraphKeys will not interfere with Keras.\n      for reg in self.kernel_regularizer:\n        tf.compat.v1.add_to_collection(\n            tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, reg(self.kernel))\n\n    if self.impute_missing:\n      if self.missing_input_value is not None:\n        self._missing_input_value_tensor = tf.constant(\n            self.missing_input_value,\n            dtype=self.dtype,\n            name=MISSING_INPUT_VALUE_NAME)\n      else:\n        self._missing_input_value_tensor = None\n\n      if self.missing_output_value is not None:\n        self.missing_output = tf.constant(\n            self.missing_output_value, shape=[1, self.units], dtype=self.dtype)\n      else:\n        missing_init = (self._output_init_min + self._output_init_max) / 2.0\n        missing_constraints = NaiveBoundsConstraints(\n            lower_bound=self.output_min, upper_bound=self.output_max)\n        self.missing_output = self.add_weight(\n            PWL_CALIBRATION_MISSING_OUTPUT_NAME,\n            shape=[1, self.units],\n            initializer=keras.initializers.Constant(value=missing_init),\n            constraint=missing_constraints,\n            dtype=self.dtype)\n\n    super(PWLCalibration, self).build(input_shape)\n\n  def call(self, inputs):\n    \"\"\"Standard Keras call() method.\n\n    Args:\n      inputs: Either an input tensor or a list of 2 elements: the input tensor\n        and the `is_missing` tensor.\n\n    Returns:\n      Calibrated input tensor.\n\n    Raises:\n      ValueError: If the `is_missing` tensor is specified incorrectly.\n    \"\"\"\n    is_missing = None\n    if isinstance(inputs, list):\n      # Only 2-element lists are allowed. 
When such list is given - second\n      # element represents 'is_missing' tensor encoded as float value.\n      if not self.impute_missing:\n        raise ValueError(\"Multiple inputs for PWLCalibration layer assume \"\n                         \"regular input tensor and 'is_missing' tensor, but \"\n                         \"this instance of a layer is not configured to handle \"\n                         \"missing value. See 'impute_missing' parameter.\")\n      if len(inputs) > 2:\n        raise ValueError(\"Multiple inputs for PWLCalibration layer assume \"\n                         \"normal input tensor and 'is_missing' tensor, but more\"\n                         \" than 2 tensors given. 'inputs': \" + str(inputs))\n      if len(inputs) == 2:\n        inputs, is_missing = inputs\n        if is_missing.shape.as_list() != inputs.shape.as_list():\n          raise ValueError(\n              \"is_missing shape %s does not match inputs shape %s for \"\n              \"PWLCalibration layer\" %\n              (str(is_missing.shape), str(inputs.shape)))\n      else:\n        [inputs] = inputs\n    if len(inputs.shape) != 2 or (inputs.shape[1] != self.units and\n                                  inputs.shape[1] != 1):\n      raise ValueError(\"Shape of input tensor for PWLCalibration layer must be \"\n                       \"[-1, units] or [-1, 1]. It is: \" + str(inputs.shape))\n\n    if self.input_keypoints_type == \"fixed\":\n      keypoints_dtype = self._interpolation_keypoints.dtype\n    else:\n      keypoints_dtype = self.interpolation_logits.dtype\n    if inputs.dtype != keypoints_dtype:\n      raise ValueError(\"dtype(%s) of input to PWLCalibration layer does not \"\n                       \"correspond to dtype(%s) of keypoints. You can enforce \"\n                       \"dtype of keypoints by explicitly providing 'dtype' \"\n                       \"parameter to layer constructor or by passing keypoints \"\n                       \"in such format which by default will be converted into \"\n                       \"desired one.\" % (inputs.dtype, keypoints_dtype))\n\n    # Here is calibration. Everything else is handling of missing.\n    if inputs.shape[1] > 1 or (self.input_keypoints_type == \"learned_interior\"\n                               and self.units > 1):\n      # Interpolation will have shape [batch_size, units, weights] in these\n      # cases. 
To prepare for that, we add a dimension to the input here to get\n      # shape [batch_size, units, 1] or [batch_size, 1, 1] if 1d input.\n      inputs_to_calibration = tf.expand_dims(inputs, -1)\n    else:\n      inputs_to_calibration = inputs\n    if self.input_keypoints_type == \"learned_interior\":\n      self._lengths = tf.multiply(\n          tf.nn.softmax(self.interpolation_logits, axis=1),\n          self._keypoint_range,\n          name=LENGTHS_NAME)\n      self._interpolation_keypoints = tf.add(\n          tf.cumsum(self._lengths, axis=1, exclusive=True),\n          self._keypoint_min,\n          name=INTERPOLATION_KEYPOINTS_NAME)\n    interpolation_weights = pwl_calibration_lib.compute_interpolation_weights(\n        inputs_to_calibration, self._interpolation_keypoints, self._lengths)\n    if self.is_cyclic:\n      # Need to add such last height to make all heights to sum up to 0.0 in\n      # order to make calibrator cyclic.\n      bias_and_heights = tf.concat(\n          [self.kernel, -tf.reduce_sum(self.kernel[1:], axis=0, keepdims=True)],\n          axis=0)\n    else:\n      bias_and_heights = self.kernel\n\n    # bias_and_heights has shape [weight, units].\n    if len(interpolation_weights.shape) > 2:\n      # Multi dim input has interpolation shape [batch_size, units, weights].\n      result = tf.reduce_sum(\n          interpolation_weights * tf.transpose(bias_and_heights), axis=-1)\n    else:\n      # Single dim input has interpolation shape [batch_size, weights].\n      result = tf.matmul(interpolation_weights, bias_and_heights)\n\n    if self.impute_missing:\n      if is_missing is None:\n        if self.missing_input_value is None:\n          raise ValueError(\"PWLCalibration layer is configured to impute \"\n                           \"missing but no 'missing_input_value' specified and \"\n                           \"'is_missing' tensor is not given.\")\n        assert self._missing_input_value_tensor is not None\n        is_missing = tf.cast(\n            tf.equal(inputs, self._missing_input_value_tensor),\n            dtype=self.dtype)\n      result = is_missing * self.missing_output + (1.0 - is_missing) * result\n\n    if self.units > 1 and self.split_outputs:\n      result = tf.split(result, self.units, axis=1)\n\n    return result\n\n  def compute_output_shape(self, input_shape):\n    \"\"\"Standard Keras compute_output_shape() method.\"\"\"\n    del input_shape\n    if self.units > 1 and self.split_outputs:\n      return [(None, 1)] * self.units\n    else:\n      return (None, self.units)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    config = {\n        \"input_keypoints\": self.input_keypoints,\n        \"units\": self.units,\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n        \"clamp_min\": self.clamp_min,\n        \"clamp_max\": self.clamp_max,\n        \"monotonicity\": self.monotonicity,\n        \"convexity\": self.convexity,\n        \"is_cyclic\": self.is_cyclic,\n        \"kernel_initializer\":\n            keras.initializers.serialize(\n                self.kernel_initializer, use_legacy_format=True),\n        \"kernel_regularizer\":\n            [keras.regularizers.serialize(r, use_legacy_format=True)\n             for r in self.kernel_regularizer],\n        \"impute_missing\": self.impute_missing,\n        \"missing_input_value\": self.missing_input_value,\n        \"num_projection_iterations\": self.num_projection_iterations,\n        \"split_outputs\": 
self.split_outputs,\n        \"input_keypoints_type\": self.input_keypoints_type,\n    }  # pyformat: disable\n    config.update(super(PWLCalibration, self).get_config())\n    return config\n\n  def assert_constraints(self, eps=1e-6):\n    \"\"\"Asserts that layer weights satisfy all constraints.\n\n    In graph mode builds and returns list of assertion ops. Note that ops will\n    be created at the moment when this function is being called.\n    In eager mode directly executes assertions.\n\n    Args:\n      eps: Allowed constraints violation.\n\n    Returns:\n      List of assertion ops in graph mode or immediately asserts in eager mode.\n    \"\"\"\n    # Assert by computing outputs for keypoints and testing them against\n    # constraints.\n    test_inputs = tf.constant(\n        value=self.input_keypoints,\n        dtype=self.dtype,\n        shape=[len(self.input_keypoints), 1])\n    outputs = self.call(test_inputs)\n\n    asserts = pwl_calibration_lib.assert_constraints(\n        outputs=outputs,\n        monotonicity=utils.canonicalize_monotonicity(self.monotonicity),\n        output_min=self.output_min,\n        output_max=self.output_max,\n        clamp_min=self.clamp_min,\n        clamp_max=self.clamp_max,\n        debug_tensors=[\"weights:\", self.kernel],\n        eps=eps)\n\n    if self.impute_missing and self.missing_output_value is None:\n      asserts.append(\n          pwl_calibration_lib.assert_constraints(\n              outputs=self.missing_output,\n              monotonicity=0,\n              output_min=self.output_min,\n              output_max=self.output_max,\n              clamp_min=False,\n              clamp_max=False,\n              debug_tensors=[\"Imputed missing value:\", self.missing_output],\n              eps=eps))\n    return asserts\n\n  def keypoints_outputs(self):\n    \"\"\"Returns tensor of keypoint outputs of shape [num_weights, num_units].\"\"\"\n    kp_outputs = tf.cumsum(self.kernel)\n    if self.is_cyclic:\n      kp_outputs = tf.concat([kp_outputs, kp_outputs[0:1]], axis=0)\n    return kp_outputs\n\n  def keypoints_inputs(self):\n    \"\"\"Returns tensor of keypoint inputs of shape [num_weights, num_units].\"\"\"\n    # We don't store the last keypoint in self._interpolation_keypoints since\n    # it is not needed for training or evaluation, but we re-add it here to\n    # align with the keypoints_outputs function.\n    if self.input_keypoints_type == \"fixed\":\n      all_keypoints = tf.concat([\n          self._interpolation_keypoints,\n          self._interpolation_keypoints[-1:] + self._lengths[-1:]\n      ],\n                                axis=0)\n      return tf.stack([all_keypoints] * self.units, axis=1)\n    else:\n      lengths = tf.nn.softmax(\n          self.interpolation_logits, axis=-1) * self._keypoint_range\n      interpolation_keypoints = tf.cumsum(\n          lengths, axis=-1, exclusive=True) + self._keypoint_min\n      all_keypoints = tf.concat([\n          interpolation_keypoints,\n          interpolation_keypoints[:, -1:] + lengths[:, -1:]\n      ],\n                                axis=1)\n      return tf.transpose(all_keypoints)\n\n\nclass UniformOutputInitializer(keras.initializers.Initializer):\n  # pyformat: disable\n  \"\"\"Initializes PWL calibration layer to represent linear function.\n\n  PWL calibration layer weights are one-d tensor. First element of tensor\n  represents bias. All remaining represent delta in y-value compare to previous\n  point. 
Aka heights of segments.\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self, output_min, output_max, monotonicity, keypoints=None):\n    # pyformat: disable\n    \"\"\"Initializes an instance of `UniformOutputInitializer`.\n\n    Args:\n      output_min: Minimum value of PWL calibration output after initialization.\n      output_max: Maximum value of PWL calibration output after initialization.\n      monotonicity:\n        - if 'none' or 'increasing', the returned function will go from\n          `(input_min, output_min)` to `(input_max, output_max)`.\n        - if 'decreasing', the returned function will go from\n          `(input_min, output_max)` to `(input_max, output_min)`.\n      keypoints:\n        - if not provided (None or []), all pieces of returned function\n          will have equal heights (i.e. `y[i+1] - y[i]` is constant).\n        - if provided, all pieces of returned function will have equal slopes\n          (i.e. `(y[i+1] - y[i]) / (x[i+1] - x[i])` is constant).\n    \"\"\"  # pyformat: enable\n    pwl_calibration_lib.verify_hyperparameters(\n        input_keypoints=keypoints,\n        output_min=output_min,\n        output_max=output_max,\n        monotonicity=monotonicity)\n    self.output_min = output_min\n    self.output_max = output_max\n    self.monotonicity = monotonicity\n    self.keypoints = keypoints\n\n  def __call__(self, shape, dtype=None, partition_info=None):\n    \"\"\"Returns weights of PWL calibration layer.\n\n    Args:\n      shape: Must be a collection of the form `(k, units)` where `k >= 2`.\n      dtype: Standard Keras initializer param.\n      partition_info: Standard Keras initializer param.\n\n    Returns:\n      Weights of PWL calibration layer.\n\n    Raises:\n      ValueError: If requested shape is invalid for PWL calibration layer\n        weights.\n    \"\"\"\n    return pwl_calibration_lib.linear_initializer(\n        shape=shape,\n        output_min=self.output_min,\n        output_max=self.output_max,\n        monotonicity=utils.canonicalize_monotonicity(self.monotonicity),\n        keypoints=self.keypoints,\n        dtype=dtype)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n        \"monotonicity\": self.monotonicity,\n        \"keypoints\": self.keypoints,\n    }  # pyformat: disable\n\n\nclass PWLCalibrationConstraints(keras.constraints.Constraint):\n  # pyformat: disable\n  \"\"\"Monotonicity and bounds constraints for PWL calibration layer.\n\n  Applies an approximate L2 projection to the weights of a PWLCalibration layer\n  such that the result satisfies the specified constraints.\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"  # pyformat: enable\n\n  def __init__(\n      self,\n      monotonicity=\"none\",\n      convexity=\"none\",\n      lengths=None,\n      output_min=None,\n      output_max=None,\n      output_min_constraints=pwl_calibration_lib.BoundConstraintsType.NONE,\n      output_max_constraints=pwl_calibration_lib.BoundConstraintsType.NONE,\n      num_projection_iterations=8):\n    \"\"\"Initializes an instance of `PWLCalibration`.\n\n    Args:\n      monotonicity: Same meaning as corresponding parameter of `PWLCalibration`.\n      convexity: Same meaning as corresponding parameter of `PWLCalibration`.\n      lengths: Lengths of pieces of piecewise linear function. 
Needed only if\n        convexity is specified.\n      output_min: Minimum possible output of pwl function.\n      output_max: Maximum possible output of pwl function.\n      output_min_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n        describing the constraints on the layer's minimum value.\n      output_max_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n        describing the constraints on the layer's maximum value.\n      num_projection_iterations: Same meaning as corresponding parameter of\n        `PWLCalibration`.\n    \"\"\"\n    pwl_calibration_lib.verify_hyperparameters(\n        output_min=output_min,\n        output_max=output_max,\n        monotonicity=monotonicity,\n        convexity=convexity,\n        lengths=lengths)\n    self.monotonicity = monotonicity\n    self.convexity = convexity\n    self.lengths = lengths\n    self.output_min = output_min\n    self.output_max = output_max\n    self.output_min_constraints = output_min_constraints\n    self.output_max_constraints = output_max_constraints\n    self.num_projection_iterations = num_projection_iterations\n\n    canonical_convexity = utils.canonicalize_convexity(self.convexity)\n    canonical_monotonicity = utils.canonicalize_monotonicity(self.monotonicity)\n    if (canonical_convexity != 0 and canonical_monotonicity == 0 and\n        (output_min_constraints != pwl_calibration_lib.BoundConstraintsType.NONE\n         or output_max_constraints !=\n         pwl_calibration_lib.BoundConstraintsType.NONE)):\n      logging.warning(\"Convexity constraints are specified with bounds \"\n                      \"constraints, but without monotonicity. Such combination \"\n                      \"might lead to convexity being slightly violated. \"\n                      \"Consider increasing num_projection_iterations to \"\n                      \"reduce violation.\")\n\n  def __call__(self, w):\n    \"\"\"Applies constraints to w.\"\"\"\n    return pwl_calibration_lib.project_all_constraints(\n        weights=w,\n        monotonicity=utils.canonicalize_monotonicity(self.monotonicity),\n        output_min=self.output_min,\n        output_max=self.output_max,\n        output_min_constraints=self.output_min_constraints,\n        output_max_constraints=self.output_max_constraints,\n        convexity=utils.canonicalize_convexity(self.convexity),\n        lengths=self.lengths,\n        num_projection_iterations=self.num_projection_iterations)\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"monotonicity\": self.monotonicity,\n        \"output_min\": self.output_min,\n        \"output_max\": self.output_max,\n        \"output_min_constraints\": self.output_min_constraints,\n        \"output_max_constraints\": self.output_max_constraints,\n        \"convexity\": self.convexity,\n        \"lengths\": self.lengths,\n        \"num_projection_iterations\": self.num_projection_iterations,\n    }  # pyformat: disable\n\n\nclass NaiveBoundsConstraints(keras.constraints.Constraint):\n  # pyformat: disable\n  \"\"\"Naively clips all elements of tensor to be within bounds.\n\n  This constraint is used only for the weight tensor for missing output value.\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"  # pyformat: enable\n\n  def __init__(self, lower_bound=None, upper_bound=None):\n    \"\"\"Initializes an instance of `NaiveBoundsConstraints`.\n\n    Args:\n      lower_bound: Lower bound to clip variable values to.\n      upper_bound: Upper 
bound to clip variable values to.\n    \"\"\"\n    self.lower_bound = lower_bound\n    self.upper_bound = upper_bound\n\n  def __call__(self, w):\n    \"\"\"Applies constraints to w.\"\"\"\n    if self.lower_bound is not None:\n      w = tf.maximum(w, self.lower_bound)\n    if self.upper_bound is not None:\n      w = tf.minimum(w, self.upper_bound)\n    return w\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"lower_bound\": self.lower_bound,\n        \"upper_bound\": self.upper_bound\n    }  # pyformat: disable\n\n\nclass LaplacianRegularizer(keras.regularizers.Regularizer):\n  # pyformat: disable\n  \"\"\"Laplacian regularizer for PWL calibration layer.\n\n  Calibrator Laplacian regularization penalizes the change in the calibration\n  output. It is defined to be:\n\n  `l1 * ||delta||_1 + l2 * ||delta||_2^2`\n\n  where `delta` is:\n\n  `output_keypoints[1:end] - output_keypoints[0:end-1]`.\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"  # pyformat: enable\n\n  def __init__(self, l1=0.0, l2=0.0, is_cyclic=False):\n    \"\"\"Initializes an instance of `LaplacianRegularizer`.\n\n    Args:\n      l1: l1 regularization amount as float.\n      l2: l2 regularization amount as float.\n      is_cyclic: Whether the first and last keypoints should take the same\n        output value.\n    \"\"\"\n    self.l1 = l1\n    self.l2 = l2\n    self.is_cyclic = is_cyclic\n\n  def __call__(self, x):\n    \"\"\"Returns regularization loss.\n\n    Args:\n      x: Tensor of shape: `(k, units)` which represents weights of PWL\n        calibration layer. First row of weights is the bias term. All remaining\n        rows represent delta in y-value compared to the previous point (segment\n        heights).\n    \"\"\"\n    if not self.l1 and not self.l2:\n      return tf.constant(0.0, dtype=x.dtype, shape=())\n    heights = x[1:]\n    if self.is_cyclic:\n      # Need to add such last height to make all heights to sum up to 0.0 in\n      # order to make calibrator cyclic.\n      heights = tf.concat(\n          [heights, -tf.reduce_sum(heights, axis=0, keepdims=True)], axis=0)\n\n    losses = []\n    if self.l1:\n      losses.append(self.l1 * tf.reduce_sum(tf.abs(heights)))\n    if self.l2:\n      losses.append(self.l2 * tf.reduce_sum(tf.square(heights)))\n\n    result = losses[0]\n    if len(losses) == 2:\n      result += losses[1]\n    return result\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"l1\": self.l1,\n        \"l2\": self.l2,\n        \"is_cyclic\": self.is_cyclic,\n    }  # pyformat: disable\n\n\nclass HessianRegularizer(keras.regularizers.Regularizer):\n  # pyformat: disable\n  \"\"\"Hessian regularizer for PWL calibration layer.\n\n  Calibrator Hessian regularizer penalizes the change in slopes of linear\n  pieces. It is defined to be:\n\n  `l1 * ||nonlinearity||_1 + l2 * ||nonlinearity||_2^2`\n\n  where `nonlinearity` is:\n\n  `2 * output_keypoints[1:end-1] - output_keypoints[0:end-2]\n     - output_keypoints[2:end]`.\n\n  This regularizer is zero when the output_keypoints form a linear function of\n  the index (and not necessarily linear in input values, e.g. 
when using\n  non-uniform input keypoints).\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"  # pyformat: enable\n\n  def __init__(self, l1=0.0, l2=0.0, is_cyclic=False):\n    \"\"\"Initializes an instance of `HessianRegularizer`.\n\n    Args:\n      l1: l1 regularization amount as float.\n      l2: l2 regularization amount as float.\n      is_cyclic: Whether the first and last keypoints should take the same\n        output value.\n    \"\"\"\n    self.l1 = l1\n    self.l2 = l2\n    self.is_cyclic = is_cyclic\n\n  def __call__(self, x):\n    \"\"\"Returns regularization loss.\n\n    Args:\n      x: Tensor of shape: `(k, units)` which represents weights of PWL\n        calibration layer. First row of weights is bias term. All remaining\n        represent delta in y-value compare to previous point (segment heights).\n    \"\"\"\n    if not self.l1 and not self.l2:\n      return tf.constant(0.0, dtype=x.dtype, shape=())\n\n    if self.is_cyclic:\n      heights = x[1:]\n      heights = tf.concat(\n          [\n              heights,\n              -tf.reduce_sum(heights, axis=0, keepdims=True),\n              heights[0:1],\n          ],\n          axis=0,\n      )\n      nonlinearity = heights[1:] - heights[:-1]\n    else:\n      nonlinearity = x[2:] - x[1:-1]\n\n    losses = []\n    if self.l1:\n      losses.append(self.l1 * tf.reduce_sum(tf.abs(nonlinearity)))\n    if self.l2:\n      losses.append(self.l2 * tf.reduce_sum(tf.square(nonlinearity)))\n\n    result = losses[0]\n    if len(losses) == 2:\n      result += losses[1]\n    return result\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"l1\": self.l1,\n        \"l2\": self.l2,\n        \"is_cyclic\": self.is_cyclic,\n    }  # pyformat: disable\n\n\nclass WrinkleRegularizer(keras.regularizers.Regularizer):\n  # pyformat: disable\n  \"\"\"Wrinkle regularizer for PWL calibration layer.\n\n  Calibrator wrinkle regularization penalizes the change in the second\n  derivative. It is defined to be:\n\n  `l1 * ||third_derivative||_1 + l2 * ||third_derivative||_2^2`\n\n  where `third_derivative` is:\n\n  `3 * output_keypoints[1:end-2] - 3 * output_keypoints[2:end-1]\n   - output_keypoints[0:end-3] + output_keypoints[3:end]`.\n\n  This regularizer is zero when the output_keypoints form a 2nd order polynomial\n  of the index (and not necessarily in input values, e.g. when using\n  non-uniform input keypoints).\n\n  Attributes:\n    - All `__init__` arguments.\n  \"\"\"  # pyformat: enable\n\n  def __init__(self, l1=0.0, l2=0.0, is_cyclic=False):\n    \"\"\"Initializes an instance of `WrinkleRegularizer`.\n\n    Args:\n      l1: l1 regularization amount as float.\n      l2: l2 regularization amount as float.\n      is_cyclic: Whether the first and last keypoints should take the same\n        output value.\n    \"\"\"\n    self.l1 = l1\n    self.l2 = l2\n    self.is_cyclic = is_cyclic\n\n  def __call__(self, x):\n    \"\"\"Returns regularization loss.\n\n    Args:\n      x: Tensor of shape: `(k, units)` which represents weights of PWL\n        calibration layer. First row of weights is bias term. 
All remaining\n        represent delta in y-value compare to previous point (segment heights).\n    \"\"\"\n    if not self.l1 and not self.l2:\n      return tf.constant(0.0, dtype=x.dtype, shape=())\n    if x.shape[0] < 3:\n      return tf.constant(0.0, dtype=x.dtype, shape=())\n\n    if self.is_cyclic:\n      heights = x[1:]\n      heights = tf.concat(\n          [\n              heights,\n              -tf.reduce_sum(heights, axis=0, keepdims=True),\n              heights[0:1],\n              heights[1:2],\n          ],\n          axis=0,\n      )\n      nonlinearity = heights[1:] - heights[:-1]\n    else:\n      nonlinearity = x[2:] - x[1:-1]\n    wrinkleness = nonlinearity[1:] - nonlinearity[0:-1]\n\n    losses = []\n    if self.l1:\n      losses.append(self.l1 * tf.reduce_sum(tf.abs(wrinkleness)))\n    if self.l2:\n      losses.append(self.l2 * tf.reduce_sum(tf.square(wrinkleness)))\n\n    result = losses[0]\n    if len(losses) == 2:\n      result += losses[1]\n    return result\n\n  def get_config(self):\n    \"\"\"Standard Keras config for serialization.\"\"\"\n    return {\n        \"l1\": self.l1,\n        \"l2\": self.l2,\n        \"is_cyclic\": self.is_cyclic,\n    }  # pyformat: disable\n"
  },
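  {
    "path": "examples/pwl_regularizer_usage_example.py",
    "content": "# NOTE: illustrative sketch only, not part of the TensorFlow Lattice library.\n# It shows how the PWL calibration regularizers defined above act on the\n# layer's weight parameterization: row 0 of the `(num_keypoints, units)` weight\n# matrix is the bias and the remaining rows are segment heights, i.e. deltas\n# between consecutive output keypoints. The file path, the module path and all\n# numeric values here are assumptions made for the sake of the example.\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_lattice.python import pwl_calibration_layer\n\n# Output keypoints of a calibrator with 5 input keypoints.\noutput_keypoints = np.array([0.0, 0.5, 1.5, 2.0, 2.25], dtype=np.float32)\n# Layer weights: bias followed by segment heights, shape (num_keypoints, units).\nweights = np.concatenate(\n    [output_keypoints[:1], np.diff(output_keypoints)]).reshape(-1, 1)\n\nwrinkle = pwl_calibration_layer.WrinkleRegularizer(l1=0.0, l2=1.0)\nloss = wrinkle(tf.constant(weights))\n\n# The same quantity computed directly from third differences of the output\n# keypoints: l2 * sum(third_difference ** 2).\nthird_difference = np.diff(output_keypoints, n=3)\nprint(float(loss), float(np.sum(third_difference ** 2)))  # Both ~1.0625.\n"
  },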
  {
    "path": "tensorflow_lattice/python/pwl_calibration_lib.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Implementation of algorithms required for PWL calibration layer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport copy\nimport enum\nfrom . import utils\nimport tensorflow as tf\n\n\nclass BoundConstraintsType(enum.Enum):\n  \"\"\"Type of bound constraints for PWL calibration.\n\n  - NONE: no constraints.\n  - BOUND: output range can be anywhere within bounds.\n  - CLAMPED: output range must exactly match bounds.\n  \"\"\"\n  NONE = 0\n  BOUND = 1\n  CLAMPED = 2\n\n\ndef convert_all_constraints(output_min, output_max, clamp_min, clamp_max):\n  \"\"\"Converts parameters of PWL calibration layer to internal format.\n\n  Args:\n    output_min: None for unconstrained bound or some numeric value.\n    output_max: None for unconstrained bound or some numeric value.\n    clamp_min: Whether to clamp pwl calibrator to value if `output_min` is not\n      None.\n    clamp_max: Whether to clamp pwl calibrator to value if `output_max` is not\n      None.\n\n  Returns:\n    \"value\" as float and appropriate value of\n    `tfl.pwl_calibration_lib.BoundConstraintsType` enum which corresponds to\n    `output_min(max)` and `clamp_min(max)`.\n  \"\"\"\n  if output_min is None:\n    output_max, output_max_constraints = _convert_constraints(\n        output_max, clamp_max)\n    output_min = output_max\n    output_min_constraints = BoundConstraintsType.NONE\n  elif output_max is None:\n    output_min, output_min_constraints = _convert_constraints(\n        output_min, clamp_min)\n    output_max = output_min\n    output_max_constraints = BoundConstraintsType.NONE\n  else:\n    output_min, output_min_constraints = _convert_constraints(\n        output_min, clamp_min)\n    output_max, output_max_constraints = _convert_constraints(\n        output_max, clamp_max)\n  return output_min, output_max, output_min_constraints, output_max_constraints\n\n\ndef _convert_constraints(value, clamp_to_value):\n  \"\"\"Converts constraints for output_min/max to internal format.\n\n  Args:\n    value: None for unconstrained bound or some numeric value.\n    clamp_to_value: Whether to clamp pwl calibrator to value if value isn't None\n\n  Returns:\n    \"value\" as float and appropriate value of\n    `tfl.pwl_calibration_lib.BoundConstraintsType` enum which\n    corresponds to `value` and `clamp_to_value`.\n  \"\"\"\n  if value is None:\n    return 0.0, BoundConstraintsType.NONE\n  else:\n    value = float(value)\n    if clamp_to_value:\n      return value, BoundConstraintsType.CLAMPED\n    else:\n      return value, BoundConstraintsType.BOUND\n\n\ndef compute_interpolation_weights(inputs, keypoints, lengths):\n  \"\"\"Computes weights for PWL calibration.\n\n  Args:\n    inputs: Tensor of shape: `(batch_size, 1)`, `(batch_size, units, 1)` or\n    `(batch_size, 1, 1)`. 
For multi-unit calibration, broadcasting will be used\n    if needed.\n    keypoints: Tensor of shape `(num_keypoints-1)` or `(units, num_keypoints-1)`\n      which represents left keypoint of pieces of piecewise linear function\n      along X axis.\n    lengths: Tensor of shape `(num_keypoints-1)` or `(units, num_keypoints-1)`\n      which represents lengths of pieces of piecewise linear function along X\n      axis.\n\n  Returns:\n    Interpolation weights tensor of shape: `(batch_size, num_keypoints)` or\n    `(batch_size, units, num_keypoints)`.\n  \"\"\"\n  weights = (inputs - keypoints) / lengths\n  weights = tf.minimum(weights, 1.0)\n  weights = tf.maximum(weights, 0.0)\n  # Prepend 1.0 at the beginning to add bias unconditionally. Worth testing\n  # different strategies, including those commented out, on different hardware.\n  if len(keypoints.shape) == 1:\n    return tf.concat([tf.ones_like(inputs), weights], axis=-1)\n  else:\n    shape = tf.concat([tf.shape(weights)[:-1], [1]], axis=0)\n    return tf.concat([tf.ones(shape), weights], axis=-1)\n  # return tf.concat([tf.ones_like(weights)[..., :1], weights], axis=-1)\n  # return tf.concat([tf.ones_like(weights[..., :1]), weights], axis=-1)\n  # paddings = [[0, 0]] * (len(weights.shape) - 1) + [[1, 0]]\n  # return tf.pad(weights, paddings, constant_values=1.)\n\n\ndef linear_initializer(shape,\n                       output_min,\n                       output_max,\n                       monotonicity,\n                       keypoints=None,\n                       dtype=None):\n  \"\"\"Initializes PWL calibration layer to represent linear function.\n\n  PWL calibration layer weights have shape `(num_keypoints, units)`. First row\n  represents bias. All remaining represent delta in y-value compare to previous\n  point. Aka heights of segments.\n\n  Args:\n    shape: Requested shape. Must be `(num_keypoints, units)`.\n    output_min: Minimum value of PWL calibration output after initialization.\n    output_max: Maximum value of PWL calibration output after initialization.\n    monotonicity: If one of {0, 1}, the returned function will go from\n      `(input_min, output_min)` to `(input_max, output_max)`. If set to -1, the\n      returned function will go from `(input_min, output_max)` to `(input_max,\n      output_min)`.\n    keypoints: If not provided (None or []), all pieces of returned function\n      will have equal heights (i.e. `y[i+1] - y[i]` is constant). If provided,\n      all pieces of returned function will have equal slopes (i.e. 
`(y[i+1] -\n      y[i]) / (x[i+1] - x[i])` is constant).\n    dtype: dtype.\n\n  Returns:\n    PWLCalibration layer weights initialized according to params.\n\n  Raises:\n    ValueError: If given parameters are inconsistent.\n  \"\"\"\n  verify_hyperparameters(\n      input_keypoints=keypoints,\n      output_min=output_min,\n      output_max=output_max,\n      monotonicity=monotonicity,\n      weights_shape=shape)\n\n  num_keypoints, units = int(shape[0]), int(shape[1])\n  if keypoints is None:\n    # Subtract 1 for bias which will be handled separately.\n    num_pieces = num_keypoints - 1\n    segment_height = (output_max - output_min) / num_pieces\n    heights_tensor = tf.constant(\n        [segment_height] * num_pieces, shape=[num_pieces, 1], dtype=dtype)\n  else:\n    keypoints_tensor = tf.constant(\n        keypoints, shape=[num_keypoints, 1], dtype=dtype)\n    lengths_tensor = keypoints_tensor[1:] - keypoints_tensor[0:-1]\n    output_range = output_max - output_min\n    heights_tensor = (\n        lengths_tensor * (output_range / tf.reduce_sum(lengths_tensor)))\n\n  if units > 1:\n    heights_tensor = tf.tile(heights_tensor, multiples=[1, units])\n\n  if monotonicity == -1:\n    bias = output_max\n    heights_tensor = -heights_tensor\n  else:\n    bias = output_min\n  bias_tensor = tf.constant(bias, shape=[1, units], dtype=dtype)\n\n  return tf.concat([bias_tensor, heights_tensor], axis=0)\n\n\ndef _approximately_project_bounds_only(bias, heights, output_min, output_max,\n                                       output_min_constraints,\n                                       output_max_constraints):\n  \"\"\"Bounds constraints implementation for PWL calibration layer.\n\n  Maps given weights of PWL calibration layer into some point which satisfies\n  given bounds by capping the function based on the bounds. This is not an exact\n  projection in L2 norm, but it is sufficiently accurate and efficient in\n  practice for non monotonic functions.\n\n  Args:\n    bias: `(1, units)`-shape tensor which represents bias.\n    heights: `(num_heights, units)`-shape tensor which represents heights.\n    output_min: Minimum possible output of pwl function.\n    output_max: Maximum possible output of pwl function.\n    output_min_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n      describing the constraints on the layer's minimum value.\n    output_max_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n      describing the constraints on the layer's maximum value.\n\n  Raises:\n    ValueError: If `output_min(max)_constraints` is set to \"CLAMPED\" which is\n      not supported.\n\n  Returns:\n    Projected bias and heights.\n  \"\"\"\n  if (output_min_constraints == BoundConstraintsType.CLAMPED or\n      output_max_constraints == BoundConstraintsType.CLAMPED):\n    raise ValueError(\"Clamping is not implemented for non monotonic functions.\")\n  if (output_min_constraints == BoundConstraintsType.NONE and\n      output_max_constraints == BoundConstraintsType.NONE):\n    return bias, heights\n\n  # Compute cumulative sums - they correspond to our calibrator outputs at\n  # keypoints. 
Simply clip them according to config and compute new heights\n  # using clipped cumulative sums.\n  sums = tf.cumsum(tf.concat([bias, heights], axis=0))\n  if output_min_constraints == BoundConstraintsType.BOUND:\n    sums = tf.maximum(sums, output_min)\n  if output_max_constraints == BoundConstraintsType.BOUND:\n    sums = tf.minimum(sums, output_max)\n\n  bias = sums[0:1]\n  heights = sums[1:] - sums[:-1]\n  return bias, heights\n\n\ndef _project_bounds_considering_monotonicity(bias, heights, monotonicity,\n                                             output_min, output_max,\n                                             output_min_constraints,\n                                             output_max_constraints):\n  \"\"\"Bounds projection given monotonicity constraints.\n\n  Projects weights of PWLCalibration layer onto the nearest point, in terms of\n  L2 distance, which satisfies the bounds constraints, taking into account that\n  the function is monotonic.\n\n  Algorithm:\n  To minimize L2 distance to projected point we want to distribute update\n  through heights as evenly as possible. A simplified description of the\n  algorithm for an increasing function is as follows:\n\n  ```\n  delta = (output_max - (bias + sum(heights[:]))) / (num_heights + 1)\n  bias = max(bias + delta, output_min)\n  heights[:] += delta\n  ```\n\n  Some details which were omitted above:\n  * If `output_min_constraints == \"CLAMPED\"` then `bias` variable becomes\n    constant (this means we can't add delta to it).\n  * If `output_max_constraints != \"CLAMPED\"` we are looking only for negative\n    delta because we are not required to stretch function to meet upper bound.\n  * If function is decreasing we multiply everything by -1 and switch min and\n    max to make it increasing.\n\n  Args:\n    bias: `(1, units)`-shape tensor which represents bias.\n    heights: `(num_heights, units)`-shape tensor which represents heights.\n    monotonicity: 1 for increasing, -1 for decreasing.\n    output_min: Lower bound constraint of PWL calibration layer.\n    output_max: Upper bound constraint of PWL calibration layer.\n    output_min_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n      describing the constraints on the layer's minimum value.\n    output_max_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n      describing the constraints on the layer's maximum value.\n\n  Returns:\n    Projected bias and heights tensors.\n\n  Raises:\n    ValueError: If monotonicity is not in: {-1, 1}.\n  \"\"\"\n  if monotonicity not in [-1, 1]:\n    raise ValueError(\"Monotonicity should be one of: [-1, 1]. 
It is: \" +\n                     str(monotonicity))\n  if monotonicity == -1:\n    # Reduce computation of projection of decreasing function to computation of\n    # projection of increasing function by multiplying everything by -1 and\n    # swapping maximums and minimums.\n    (projected_bias,\n     projected_heights) = _project_bounds_considering_monotonicity(\n         bias=-bias,\n         heights=-heights,\n         monotonicity=1,\n         output_min=None if output_max is None else -output_max,\n         output_max=None if output_min is None else -output_min,\n         output_min_constraints=output_max_constraints,\n         output_max_constraints=output_min_constraints)\n    return -projected_bias, -projected_heights\n\n  bct = BoundConstraintsType\n  if output_max_constraints != bct.NONE:\n    num_heights = float(heights.shape.dims[0].value)\n    sum_heights = tf.reduce_sum(heights, axis=0)\n\n    # For each possible output_min_constraints value compute projected bias and\n    # heights_delta.\n    if output_min_constraints == bct.CLAMPED:\n      # If output_min is clamped - bias must have fixed value and number of free\n      # parameters is equal to number of heights.\n      bias = tf.constant(output_min, shape=bias.shape, dtype=bias.dtype)\n      heights_delta = (output_max - (bias + sum_heights)) / num_heights\n    elif output_min_constraints == bct.BOUND:\n      # If output_min is not clamped then number of free parameters is\n      # num_heights + 1.\n      bias_delta = (output_max - (bias + sum_heights)) / (num_heights + 1)\n      if output_max_constraints != bct.CLAMPED:\n        # If output_max is not clamped - there is no need to stretch our\n        # function. We need only to squeeze it.\n        bias_delta = tf.minimum(bias_delta, 0.0)\n      bias = tf.maximum(bias + bias_delta, output_min)\n      # For this branch compute heights delta _after_ we applied bias projection\n      # because heights are not bound by output_min constraint unlike bias.\n      heights_delta = (output_max - (bias + sum_heights)) / num_heights\n    else:\n      bias_delta = (output_max - (bias + sum_heights)) / (num_heights + 1)\n      # For this branch heights delta and bias delta are same because none of\n      # them are bounded from below.\n      heights_delta = bias_delta\n      if output_max_constraints != bct.CLAMPED:\n        # If output_max is not clamped - there is no need to stretch our\n        # function. We need only to squeeze it.\n        bias_delta = tf.minimum(bias_delta, 0.0)\n      bias += bias_delta\n\n    if output_max_constraints != bct.CLAMPED:\n      # If output_max is not clamped - there is no need to stretch our function.\n      # We need only to squeeze it.\n      heights_delta = tf.minimum(heights_delta, 0.0)\n    heights += heights_delta\n  else:\n    # No need to do anything with heights if there are no output_max\n    # constraints.\n    if output_min_constraints == bct.CLAMPED:\n      bias = tf.constant(output_min, shape=bias.shape, dtype=bias.dtype)\n    elif output_min_constraints == bct.BOUND:\n      bias = tf.maximum(bias, output_min)\n\n  return bias, heights\n\n\ndef _project_convexity(heights, lengths, convexity, constraint_group):\n  \"\"\"Convexity projection for given 'constraint_group'.\n\n  Since an exact single step projection is not possible for convexity\n  constraints, we break the constraints into two independent groups and apply\n  Dykstra's alternating projections algorithm. 
Each group consists of a list of\n  pairs where each pair represents constraints on 2 consecutive heights.\n\n  Groups:\n\n  ```\n  g0 = [(h0, h1), (h2, h3), (h4, h5), ...]\n  g1 = [(h1, h2), (h3, h4), (h5, h6), ...]\n  ```\n\n  We know how to project a single pair of adjacent heights:\n  h0_prime = min/max(h0, (l0 / (l0 + l1)) * (h0 + h1))\n  h1_prime = min/max(h1, (l1 / (l0 + l1)) * (h0 + h1))\n  where l0 and l1 stand for lengths of segment which correspond to h0 and h1 and\n  the choice of min or max functions depends on convexity direction.\n\n  We can see that all pairs within same group are independent so we know how to\n  project such group of constraints in a single pass.\n\n  This function breaks heights and their lengths into the given constraint\n  group and does projection for this group.\n\n  Args:\n    heights: `(num_heights, units)`-shape tensor which represents heights.\n    lengths: `(num_heights)`-shape tensor which represents lengths of segments\n      which correspond to heights.\n    convexity: -1 or 1 where 1 stands for convex function and -1 for concave.\n    constraint_group: 0 or 1 which represents the group from the description\n      above.\n\n  Returns:\n    Projected heights for given constraint group.\n  \"\"\"\n  verify_hyperparameters(\n      convexity=convexity,\n      lengths=lengths,\n      weights_shape=[heights.shape[0] + 1, heights.shape[1]])\n  if constraint_group not in [0, 1]:\n    raise ValueError(\"constraint_group must be one of: [0, 1]. \"\n                     \"Given: %s\" % constraint_group)\n\n  if convexity == 0 or heights.shape[0] == 1:\n    return heights\n\n  num_heights = heights.shape.dims[0].value\n  # To avoid broadcasting when performing math ops with 'heights'.\n  lengths = tf.reshape(lengths, shape=(-1, 1))\n\n  # Split heights and lengths into pairs which correspond to given constraint\n  # group. In order to do this we need to split heights into odd and even. 
We\n  # can possibly omit last element of larger set to ensure that both sets have\n  # same number of elements.\n  num_0 = (num_heights - constraint_group + 1) // 2\n  num_1 = (num_heights - constraint_group) // 2\n  if num_1 == num_0:\n    last_index = None\n  else:\n    last_index = -1\n  heights_0 = heights[constraint_group:last_index:2]\n  lengths_0 = lengths[constraint_group:last_index:2]\n  heights_1 = heights[constraint_group + 1::2]\n  lengths_1 = lengths[constraint_group + 1::2]\n\n  # h0_prime = (l0 / (l0 + l1)) * (h0 + h1) = l0 * base\n  # h1_prime = (l1 / (l0 + l1)) * (h0 + h1) = l1 * base\n  base = (heights_0 + heights_1) / (lengths_0 + lengths_1)\n  heights_0_prime = lengths_0 * base\n  heights_1_prime = lengths_1 * base\n  if convexity == 1:\n    heights_0 = tf.minimum(heights_0, heights_0_prime)\n    heights_1 = tf.maximum(heights_1, heights_1_prime)\n  else:\n    heights_0 = tf.maximum(heights_0, heights_0_prime)\n    heights_1 = tf.minimum(heights_1, heights_1_prime)\n\n  # Now we need to merge heights in such way that elements from 'heights_0' and\n  # 'heights_1' alternate:\n  # merged = [heights_0[0], heights_1[0], heights_0[1], heights_1[1], ...]\n  # Achieve this by concatenating along axis=1 so after concatenation elements\n  # from 'heights_0' and 'heights_1' will alternate in memory and reshape will\n  # give us desired result.\n  projected_heights = tf.reshape(\n      tf.concat([heights_0, heights_1], axis=1), shape=[-1, heights.shape[1]])\n\n  weights_pieces = [projected_heights]\n  if constraint_group == 1:\n    # First height was skipped during initial split.\n    weights_pieces = [heights[0:1]] + weights_pieces\n  if last_index == -1:\n    # Last height was skipped during initial split.\n    weights_pieces.append(heights[-1:])\n\n  if len(weights_pieces) == 1:\n    return weights_pieces[0]\n  else:\n    return tf.concat(weights_pieces, axis=0)\n\n\ndef _project_monotonicity(heights, monotonicity):\n  \"\"\"Projects into monotonic function.\"\"\"\n  if monotonicity == 0:\n    return heights\n  elif monotonicity == 1:\n    return tf.maximum(heights, 0.0)\n  else:\n    return tf.minimum(heights, 0.0)\n\n\ndef project_all_constraints(weights,\n                            monotonicity,\n                            output_min,\n                            output_max,\n                            output_min_constraints,\n                            output_max_constraints,\n                            convexity,\n                            lengths,\n                            num_projection_iterations=8):\n  \"\"\"Jointly projects into all supported constraints.\n\n  For all combinations of constraints except the case where bounds constraints\n  are specified without monotonicity constraints we properly project into\n  nearest point with respect to L2 norm. For latter case we use a heuristic to\n  map input point into some feasible point with no guarantees on how close this\n  point is to the true projection.\n\n  If only bounds or only monotonicity constraints are specified there will be a\n  single step projection. For all other combinations of constraints we use\n  num_projection_iterations iterations of Dykstra's alternating projection\n  algorithm to jointly project onto all the given constraints. Dykstra's\n  algorithm gives us proper projection with respect to L2 norm but approaches it\n  from \"wrong\" side. 
That's why in order to ensure that constraints are strictly\n  met we'll do approximate projections in the end which project strictly into\n  feasible space, but it's not an exact projection with respect to the L2 norm.\n  With enough iterations of the Dykstra's algorithm, the impact of such\n  approximate projection should be negligible.\n\n  With bound and convexity constraints and no specified monotonicity, this\n  method does not fully satisfy the constrains. Increasing the number of\n  iterations can reduce the constraint violation in such cases.\n\n  Args:\n    weights: `(num_keypoints, units)`-shape tensor which represents weights of\n      PWL calibration layer.\n    monotonicity: 1 for increasing, -1 for decreasing, 0 for no monotonicity\n      constraints.\n    output_min: Lower bound constraint of PWL calibration layer.\n    output_max: Upper bound constraint of PWL calibration layer.\n    output_min_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n      describing the constraints on the layer's minimum value.\n    output_max_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n      describing the constraints on the layer's maximum value.\n    convexity: 1 for convex, -1 for concave, 0 for no convexity constraints.\n    lengths: Lengths of pieces of piecewise linear function. Needed only if\n      convexity projection is specified.\n    num_projection_iterations: Number of iterations of Dykstra's alternating\n      projection algorithm.\n\n  Returns:\n    Projected weights tensor.\n  \"\"\"\n  bias = weights[0:1]\n  heights = weights[1:]\n\n  def body(projection_counter, bias, heights, last_bias_change,\n           last_heights_change):\n    \"\"\"The body of tf.while_loop implementing a step of Dykstra's projection.\n\n    Args:\n      projection_counter: The counter tensor or number at the beginning of the\n        iteration.\n      bias: Bias tensor at the beginning of the iteration.\n      heights: Heights tensor at the beginning of the iteration.\n      last_bias_change: Dict that stores the last change in the bias after\n        projecting onto each subset of constraints.\n      last_heights_change: Dict that stores the last change in the heights after\n        projecting onto each subset of constraints.\n\n    Returns:\n      The tuple `(num_projection_counter, bias, heights, last_bias_change,\n      last_heights_change)` at the end of the iteration.\n    \"\"\"\n    last_bias_change = copy.copy(last_bias_change)\n    last_heights_change = copy.copy(last_heights_change)\n    num_projections = 0\n    # ******************** BOUNDS *********************\n    bct = BoundConstraintsType\n    if output_min_constraints != bct.NONE or output_max_constraints != bct.NONE:\n      rolled_back_bias = bias - last_bias_change[\"BOUNDS\"]\n      rolled_back_heights = heights - last_heights_change[\"BOUNDS\"]\n      if monotonicity != 0:\n        bias, heights = _project_bounds_considering_monotonicity(\n            bias=rolled_back_bias,\n            heights=rolled_back_heights,\n            monotonicity=monotonicity,\n            output_min=output_min,\n            output_max=output_max,\n            output_min_constraints=output_min_constraints,\n            output_max_constraints=output_max_constraints)\n      else:\n        bias, heights = _approximately_project_bounds_only(\n            bias=rolled_back_bias,\n            heights=rolled_back_heights,\n            output_min=output_min,\n            output_max=output_max,\n            
output_min_constraints=output_min_constraints,\n            output_max_constraints=output_max_constraints)\n      last_bias_change[\"BOUNDS\"] = bias - rolled_back_bias\n      last_heights_change[\"BOUNDS\"] = heights - rolled_back_heights\n      num_projections += 1\n\n    # ******************** MONOTONICITY *********************\n    if monotonicity != 0:\n      rolled_back_heights = heights - last_heights_change[\"MONOTONICITY\"]\n      heights = _project_monotonicity(\n          heights=rolled_back_heights, monotonicity=monotonicity)\n      last_heights_change[\"MONOTONICITY\"] = heights - rolled_back_heights\n      num_projections += 1\n\n    # ******************** CONVEXITY *********************\n    if convexity != 0:\n      if heights.shape[0] >= 2:\n        rolled_back_heights = heights - last_heights_change[\"CONVEXITY_0\"]\n        heights = _project_convexity(\n            heights=rolled_back_heights,\n            lengths=lengths,\n            convexity=convexity,\n            constraint_group=0)\n        last_heights_change[\"CONVEXITY_0\"] = heights - rolled_back_heights\n        num_projections += 1\n      if heights.shape[0] >= 3:\n        rolled_back_heights = heights - last_heights_change[\"CONVEXITY_1\"]\n        heights = _project_convexity(\n            heights=rolled_back_heights,\n            lengths=lengths,\n            convexity=convexity,\n            constraint_group=1)\n        last_heights_change[\"CONVEXITY_1\"] = heights - rolled_back_heights\n        num_projections += 1\n\n    return (projection_counter + num_projections, bias, heights,\n            last_bias_change, last_heights_change)\n\n  # Call the body of the loop once to see if Dykstra's is needed.\n  # If there is only one set of projections, apply it without a loop.\n  # Running the body of the loop also finds the required last_bias_change\n  # and last_heights_change keys. 
The set of keys in the input and output of the\n  # body of tf.while_loop must be the same across iterations.\n  zero_bias = tf.zeros_like(bias)\n  zero_heights = tf.zeros_like(heights)\n  last_bias_change = collections.defaultdict(lambda: zero_bias)\n  last_heights_change = collections.defaultdict(lambda: zero_heights)\n  (num_projections, projected_bias, projected_heights, last_bias_change,\n   last_heights_change) = body(0, bias, heights, last_bias_change,\n                               last_heights_change)\n  if num_projections <= 1:\n    return tf.concat([projected_bias, projected_heights], axis=0)\n\n  def cond(projection_counter, bias, heights, last_bias_change,\n           last_heights_change):\n    del bias, heights, last_bias_change, last_heights_change\n    return tf.less(projection_counter,\n                   num_projection_iterations * num_projections)\n\n  # Apply Dykstra's algorithm with tf.while_loop.\n  projection_counter = tf.constant(0)\n  last_bias_change = {k: zero_bias for k in last_bias_change}\n  last_heights_change = {k: zero_heights for k in last_heights_change}\n  (_, bias, heights, _,\n   _) = tf.while_loop(cond, body, (projection_counter, bias, heights,\n                                   last_bias_change, last_heights_change))\n\n  # Since Dykstra's algorithm is iterative in order to strictly meet constraints\n  # we use approximate projection algorithm to finalize them.\n  return _finalize_constraints(\n      bias=bias,\n      heights=heights,\n      monotonicity=monotonicity,\n      output_min=output_min,\n      output_max=output_max,\n      output_min_constraints=output_min_constraints,\n      output_max_constraints=output_max_constraints,\n      convexity=convexity,\n      lengths=lengths)\n\n\ndef _squeeze_by_scaling(bias, heights, monotonicity, output_min, output_max,\n                        output_min_constraints, output_max_constraints):\n  \"\"\"Squeezes monotonic calibrators by scaling in order to meet bounds.\n\n  Projection by scaling is not exact with respect to the L2 norm, but maintains\n  convexity unlike projection by shift.\n\n  Args:\n    bias: `(1, units)`-shape tensor which represents bias.\n    heights: `(num_heights, units)`-shape tensor which represents heights.\n    monotonicity: 1 for increasing, -1 for decreasing.\n    output_min: Lower bound constraint of PWL calibration layer.\n    output_max: Upper bound constraint of PWL calibration layer.\n    output_min_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n      describing the constraints on the layer's minimum value.\n    output_max_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n      describing the constraints on the layer's maximum value.\n\n  Returns:\n    Projected bias and heights.\n  \"\"\"\n  if monotonicity == -1:\n    if output_min_constraints == BoundConstraintsType.NONE:\n      return bias, heights\n    # Reduce computation of projection of decreasing function to computation of\n    # projection of increasing function by multiplying everything by -1 and\n    # swapping maximums and minimums.\n    bias, heights = _squeeze_by_scaling(\n        bias=-bias,\n        heights=-heights,\n        monotonicity=1,\n        output_min=None if output_max is None else -output_max,\n        output_max=None if output_min is None else -output_min,\n        output_min_constraints=output_max_constraints,\n        output_max_constraints=output_min_constraints)\n    return -bias, -heights\n  if output_max_constraints == BoundConstraintsType.NONE:\n    return 
bias, heights\n\n  delta = output_max - bias\n  # For better stability use tf.where rather than the more standard approach:\n  # heights *= tf.reduce_sum(heights) / max(delta, eps)\n  # so that everything stays strictly unchanged for small deltas, rather than\n  # scaling heights by a factor of 1/eps and still failing to meet constraints.\n  scaling_factor = tf.where(delta > 0.001,\n                            tf.reduce_sum(heights, axis=0) / delta,\n                            tf.ones_like(delta))\n  heights = heights / tf.maximum(scaling_factor, 1.0)\n  return bias, heights\n\n\ndef _approximately_project_convexity(heights, lengths, convexity):\n  \"\"\"Strictly projects convexity, but is not exact with respect to the L2 norm.\n\n  Projects by iterating over pieces of piecewise linear function left to right\n  and aligning current slope with previous one if it violates convexity.\n\n  Args:\n    heights: `(num_heights, units)`-shape tensor which represents heights.\n    lengths: `(num_heights)`-shape tensor which represents lengths of segments\n      which correspond to heights.\n    convexity: -1 or 1 where 1 stands for convex function and -1 for concave.\n\n  Returns:\n    Projected heights.\n  \"\"\"\n  if convexity == 0:\n    return heights\n  heights = tf.unstack(heights, axis=0)\n  lengths = tf.unstack(lengths, axis=0)\n  for i in range(1, len(heights)):\n    temp = heights[i - 1] * (lengths[i] / lengths[i - 1])\n    if convexity == 1:\n      heights[i] = tf.maximum(heights[i], temp)\n    else:\n      heights[i] = tf.minimum(heights[i], temp)\n\n  return tf.stack(heights, axis=0)\n\n\ndef _finalize_constraints(bias, heights, monotonicity, output_min, output_max,\n                          output_min_constraints, output_max_constraints,\n                          convexity, lengths):\n  \"\"\"Strictly projects onto the given constraints, approximate w.r.t. the L2 norm.\n\n  Dykstra's algorithm gives us proper projection with respect to L2 norm but\n  approaches it from \"wrong\" side. In order to ensure that constraints are\n  strictly met we'll do approximate projections in the end which project\n  strictly into feasible space, but it's not an exact projection with respect to\n  the L2 norm. With enough iterations of the Dykstra's algorithm, the impact of\n  such approximate projection should be negligible.\n\n  With bound and convexity constraints and no specified monotonicity, this\n  method does not fully satisfy the constraints. Increasing the number of\n  iterations can reduce the constraint violation in such cases. Fortunately this\n  does not seem to be a common config.\n\n  Args:\n    bias: `(1, units)`-shape tensor which represents bias.\n    heights: `(num_heights, units)`-shape tensor which represents heights.\n    monotonicity: 1 for increasing, -1 for decreasing, 0 for no monotonicity\n      constraints.\n    output_min: Lower bound constraint of PWL calibration layer.\n    output_max: Upper bound constraint of PWL calibration layer.\n    output_min_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n      describing the constraints on the layer's minimum value.\n    output_max_constraints: A `tfl.pwl_calibration_lib.BoundConstraintsType`\n      describing the constraints on the layer's maximum value.\n    convexity: 1 for convex, -1 for concave, 0 for no convexity constraints.\n    lengths: Lengths of pieces of piecewise linear function. 
Needed only if\n      convexity projection is specified.\n\n  Returns:\n    Projected weights tensor.\n  \"\"\"\n  # Convexity and monotonicity projections don't violate each other, but both\n  # might lead to bounds violation, so do them first and fix bounds after.\n  if monotonicity != 0:\n    heights = _project_monotonicity(heights=heights, monotonicity=monotonicity)\n  if convexity != 0:\n    heights = _approximately_project_convexity(\n        heights=heights, lengths=lengths, convexity=convexity)\n\n  bct = BoundConstraintsType\n  if output_min_constraints != bct.NONE or output_max_constraints != bct.NONE:\n    if monotonicity != 0 and convexity != 0:\n      # Both monotonicity and convexity projection can only increase upper bound\n      # so we only need to take care of decreasing it back.\n      bias, heights = _squeeze_by_scaling(\n          bias=bias,\n          heights=heights,\n          monotonicity=monotonicity,\n          output_min=output_min,\n          output_max=output_max,\n          output_min_constraints=output_min_constraints,\n          output_max_constraints=output_max_constraints)\n    else:\n      # This bounds projection might violate convexity. Unfortunately bounds\n      # projections with convexity and without monotonicity are difficult to\n      # achieve strictly and might be violated, so ignore this for now. In order\n      # to minimize projection error consider increasing\n      # num_projection_iterations.\n      if output_min_constraints == bct.CLAMPED:\n        output_min_constraints = bct.BOUND\n      if output_max_constraints == bct.CLAMPED:\n        output_max_constraints = bct.BOUND\n      bias, heights = _approximately_project_bounds_only(\n          bias=bias,\n          heights=heights,\n          output_min=output_min,\n          output_max=output_max,\n          output_min_constraints=output_min_constraints,\n          output_max_constraints=output_max_constraints)\n  return tf.concat([bias, heights], axis=0)\n\n\ndef assert_constraints(outputs,\n                       monotonicity,\n                       output_min,\n                       output_max,\n                       clamp_min=False,\n                       clamp_max=False,\n                       debug_tensors=None,\n                       eps=1e-6):\n  \"\"\"Asserts that 'outputs' satisfy constraints.\n\n  Args:\n    outputs: Tensor of shape `(num_output_values, units)` which represents\n      outputs of pwl calibration layer which will be tested against the given\n      constraints. If monotonicity is specified these outputs must be for\n      consecutive inputs.\n    monotonicity: One of {-1, 0, 1}. 
-1 for decreasing, 1 for increasing, 0 means\n      no monotonicity checks.\n    output_min: Lower bound or None.\n    output_max: Upper bound or None.\n    clamp_min: Whether one of outputs must match output_min.\n    clamp_max: Whether one of outputs must match output_max.\n    debug_tensors: None or list of anything convertible to tensor (for example\n      tensors or strings) which will be printed in case of constraints\n      violation.\n    eps: Allowed constraints violation.\n\n  Raises:\n    ValueError: If monotonicity is not one of {-1, 0, 1}.\n\n  Returns:\n    List of assertion ops in graph mode or immediately asserts in eager mode.\n  \"\"\"\n\n  info = [\"Outputs: \", outputs, \"Epsilon: \", eps]\n  if debug_tensors:\n    info += debug_tensors\n  asserts = []\n\n  if output_min is not None:\n    min_output = tf.reduce_min(outputs, axis=0)\n    if clamp_min:\n      asserts.append(\n          tf.Assert(\n              tf.reduce_all(tf.abs(min_output - output_min) <= eps),\n              data=[\"Clamp_min violation.\", \"output_min:\", output_min] + info,\n              summarize=outputs.shape[0]))\n    else:\n      asserts.append(\n          tf.Assert(\n              tf.reduce_all(min_output >= output_min - eps),\n              data=[\"Lower bound violation.\", \"output_min:\", output_min] + info,\n              summarize=outputs.shape[0]))\n\n  if output_max is not None:\n    max_output = tf.reduce_max(outputs, axis=0)\n    if clamp_max:\n      asserts.append(\n          tf.Assert(\n              tf.reduce_all(tf.abs(max_output - output_max) <= eps),\n              data=[\"Clamp_max violation.\", \"output_max:\", output_max] + info,\n              summarize=outputs.shape[0]))\n    else:\n      asserts.append(\n          tf.Assert(\n              tf.reduce_all(max_output <= output_max + eps),\n              data=[\"Upper bound violation.\", \"output_max:\", output_max] + info,\n              summarize=outputs.shape[0]))\n\n  if monotonicity not in [-1, 0, 1]:\n    raise ValueError(\"'monotonicity' must be one of: [-1, 0, 1]. 
It is: %s\" %\n                     monotonicity)\n  if monotonicity != 0:\n    diffs = (outputs[1:] - outputs[0:-1])\n    asserts.append(\n        tf.Assert(\n            tf.reduce_min(diffs * monotonicity) >= -eps,\n            data=[\"Monotonicity violation.\", \"monotonicity:\", monotonicity] +\n            info,\n            summarize=outputs.shape[0]))\n\n  return asserts\n\n\ndef verify_hyperparameters(input_keypoints=None,\n                           output_min=None,\n                           output_max=None,\n                           monotonicity=None,\n                           convexity=None,\n                           is_cyclic=False,\n                           lengths=None,\n                           weights_shape=None,\n                           input_keypoints_type=None):\n  \"\"\"Verifies that all given hyperparameters are consistent.\n\n  See PWLCalibration class level comment for detailed description of arguments.\n\n  Args:\n    input_keypoints: `input_keypoints` of PWLCalibration layer.\n    output_min: Smallest output of PWLCalibration layer.\n    output_max: Largest output of PWLCalibration layer.\n    monotonicity: `monotonicity` hyperparameter of PWLCalibration layer.\n    convexity: `convexity` hyperparameter of PWLCalibration layer.\n    is_cyclic: `is_cyclic` hyperparameter of PWLCalibration layer.\n    lengths: Lengths of pieces of piecewise linear function.\n    weights_shape: Shape of weights of PWLCalibration layer.\n    input_keypoints_type: The type of input keypoints of a PWLCalibration layer.\n\n  Raises:\n    ValueError: If something is inconsistent.\n  \"\"\"\n  if input_keypoints is not None:\n    if tf.is_tensor(input_keypoints):\n      if len(input_keypoints.shape) != 1 or input_keypoints.shape[0] < 2:\n        raise ValueError(\"Input keypoints must be rank-1 tensor of size at \"\n                         \"least 2. It is: \" + str(input_keypoints))\n    else:\n      if len(input_keypoints) < 2:\n        raise ValueError(\"At least 2 input keypoints must be provided. \"\n                         \"Given: \" + str(input_keypoints))\n      if not all(input_keypoints[i] < input_keypoints[i + 1]\n                 for i in range(len(input_keypoints) - 1)):\n        raise ValueError(\"Keypoints must be strictly increasing. They are: \" +\n                         str(input_keypoints))\n\n  if output_min is not None and output_max is not None:\n    if output_max < output_min:\n      raise ValueError(\"If specified output_max must be greater than \"\n                       \"output_min. \"\n                       \"They are: ({}, {})\".format(output_min, output_max))\n\n  # It also raises errors if monotonicities specified incorrectly.\n  monotonicity = utils.canonicalize_monotonicity(monotonicity)\n  convexity = utils.canonicalize_convexity(convexity)\n\n  if is_cyclic and (monotonicity or convexity):\n    raise ValueError(\"'is_cyclic' can not be specified together with \"\n                     \"'monotonicity'({}) or 'convexity'({}).\".format(\n                         monotonicity, convexity))\n\n  if weights_shape is not None:\n    if len(weights_shape) != 2 or weights_shape[0] < 2:\n      raise ValueError(\"PWLCalibrator weights must have shape: [k, units] where\"\n                       \" k > 1. 
It is: \" + str(weights_shape))\n\n  if lengths is not None and weights_shape is not None:\n    if tf.is_tensor(lengths):\n      num_lengths = lengths.shape[0]\n    else:\n      num_lengths = len(lengths)\n    if num_lengths + 1 != weights_shape[0]:\n      raise ValueError(\"Number of lengths must be equal to number of weights \"\n                       \"minus one. Lengths: %s, weights_shape: %s\" %\n                       (lengths, weights_shape))\n\n  if (input_keypoints_type is not None and input_keypoints_type != \"fixed\" and\n      input_keypoints_type != \"learned_interior\"):\n    raise ValueError(\n        \"input_keypoints_type must be one of 'fixed' or 'learned_interior': %s\"\n        % input_keypoints_type)\n"
  },
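  {
    "path": "examples/pwl_calibration_lib_interpolation_example.py",
    "content": "# NOTE: illustrative sketch only, not part of the TensorFlow Lattice library.\n# It demonstrates the parameterization used throughout pwl_calibration_lib.py\n# above: a calibrator output is the matmul of the interpolation weights\n# returned by compute_interpolation_weights() (a leading 1 for the bias, then\n# per-segment fractions clipped to [0, 1]) with the layer weights (bias\n# followed by segment heights). The file path and all numeric values are\n# assumptions made for the sake of the example.\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_lattice.python import pwl_calibration_lib\n\ninput_keypoints = np.array([0.0, 1.0, 3.0, 4.0], dtype=np.float32)\nkeypoints = input_keypoints[:-1]    # Left keypoint of each linear piece.\nlengths = np.diff(input_keypoints)  # Length of each piece along the X axis.\n\n# Weights of a calibrator mapping the input keypoints to [1.0, 2.0, 2.5, 4.0]:\n# bias 1.0 followed by segment heights 1.0, 0.5 and 1.5.\nweights = tf.constant([[1.0], [1.0], [0.5], [1.5]], dtype=tf.float32)\n\ninputs = tf.constant([[0.5], [2.0], [4.0]], dtype=tf.float32)  # (batch, 1).\ninterpolation_weights = pwl_calibration_lib.compute_interpolation_weights(\n    inputs, keypoints, lengths)  # Shape: (batch, num_keypoints).\noutputs = tf.matmul(interpolation_weights, weights)\nprint(outputs.numpy().ravel())  # Expected approximately [1.5, 2.25, 4.0].\n"
  },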
  {
    "path": "tensorflow_lattice/python/pwl_calibration_test.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for PWL calibration layer.\n\nThis test should be run with \"-c opt\" since otherwise it's slow.\nAlso, to only run a subset of the tests (useful when developing a new test or\nset of tests), change the initialization of the _disable_all boolean to 'True'\nin the SetUp method, and comment out the check for this boolean in those tests\nthat you want to run.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\n\nfrom absl import logging\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_lattice.python import parallel_combination_layer as parallel_combination\nfrom tensorflow_lattice.python import pwl_calibration_layer as keras_layer\nfrom tensorflow_lattice.python import test_utils\nfrom tensorflow_lattice.python import utils\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n\nclass CalibrateWithSeparateMissing(keras.layers.Layer):\n  \"\"\"Create separate is_missing tensor.\n\n  Splits input tensor into list: [input_tensor, is_missing_tensor] and passes\n  this list as input to given calibration layer.\n  \"\"\"\n\n  def __init__(self, calibration_layer, missing_input_value):\n    super(CalibrateWithSeparateMissing, self).__init__()\n    self.calibration_layer = calibration_layer\n    self.missing_input_value = missing_input_value\n\n  def call(self, x):\n    is_missing = tf.cast(\n        tf.equal(x, self.missing_input_value), dtype=tf.float32)\n    return self.calibration_layer([x, is_missing])\n\n\nclass PwlCalibrationLayerTest(parameterized.TestCase, tf.test.TestCase):\n\n  def setUp(self):\n    super(PwlCalibrationLayerTest, self).setUp()\n    self._disable_all = False\n    self._loss_eps = 0.0001\n    self._small_eps = 1e-6\n    keras.utils.set_random_seed(42)\n\n  def _ResetAllBackends(self):\n    keras.backend.clear_session()\n    tf.compat.v1.reset_default_graph()\n\n  def _ScatterXUniformly(self, units, num_points, input_min, input_max,\n                         missing_probability, missing_input_value):\n    \"\"\"Randomly uniformly scatters points across input space.\"\"\"\n    np.random.seed(41)\n    x = [\n        input_min + np.random.random(units) * (input_max - input_min)\n        for _ in range(num_points)\n    ]\n    if missing_probability > 0.0:\n      is_missings = np.random.random([num_points, units]) < missing_probability\n      x = [\n          is_missing * missing_input_value + (1. 
- is_missing) * point\n          for point, is_missing in zip(x, is_missings)\n      ]\n    x.sort(key=np.sum)\n    return x\n\n  def _ScatterXUniformlyIncludeBounds(self, units, **kwargs):\n    \"\"\"Same as _ScatterXUniformly() but includes bounds.\"\"\"\n    x = self._ScatterXUniformly(units, **kwargs)\n    x[0] = np.array([kwargs[\"input_min\"]] * units)\n    x[-1] = np.array([kwargs[\"input_max\"]] * units)\n    return x\n\n  def _SmallWaves(self, x):\n    return np.mean(\n        np.power(x, 3) + 0.1 * np.sin(x * math.pi * 8), keepdims=True)\n\n  def _SmallWavesPlusOne(self, x):\n    return self._SmallWaves(x) + 1.0\n\n  def _WavyParabola(self, x):\n    return np.mean(\n        np.power(x, 2) + 0.1 * np.sin(x * math.pi * 8) - 0.5, keepdims=True)\n\n  def _SinCycle(self, x):\n    # Almost entire cycle of sin.\n    return np.mean(np.sin(x / 26.0 * (2.0 * math.pi)), keepdims=True)\n\n  def _GenPWLFunction(self, input_keypoints, pwl_weights):\n    \"\"\"Returns python function equivalent to PWL calibration layer.\n\n    Output of returned function is equivalent ot output of PWL calibration layer\n    with keypoints being 'input_keypoints' and learned weights being\n    'pwl_weights'.\n\n    Args:\n      input_keypoints: list of keypoints of PWL calibration layer.\n      pwl_weights: list of weights of PWL calibration layer.\n    \"\"\"\n\n    def Pwl(x):\n      result = pwl_weights[0]\n      for begin, end, weight in zip(input_keypoints[0:-1], input_keypoints[1:],\n                                    pwl_weights[1:]):\n        result += weight * np.maximum(\n            np.minimum((x - begin) / (end - begin), 1.0), 0.0)\n      return np.mean(result, keepdims=True)\n\n    return Pwl\n\n  def _SetDefaults(self, config):\n    config.setdefault(\"units\", 1)\n    config.setdefault(\"use_multi_calibration_layer\", False)\n    config.setdefault(\"one_d_input\", False)\n    config.setdefault(\"use_separate_missing\", False)\n    config.setdefault(\"output_min\", None)\n    config.setdefault(\"output_max\", None)\n    config.setdefault(\"missing_input_value\", None)\n    config.setdefault(\"missing_output_value\", None)\n    config.setdefault(\"monotonicity\", 0)\n    config.setdefault(\"convexity\", 0)\n    config.setdefault(\"is_cyclic\", False)\n    config.setdefault(\"clamp_min\", False)\n    config.setdefault(\"clamp_max\", False)\n    config.setdefault(\"initializer\", \"equal_heights\")\n    config.setdefault(\"kernel_regularizer\", None)\n    config.setdefault(\"impute_missing\", False)\n    config.setdefault(\"missing_probability\", 0.0)\n    config.setdefault(\"num_projection_iterations\", 8)\n    config.setdefault(\"constraint_assertion_eps\", 1e-6)\n    config.setdefault(\"model_dir\", \"/tmp/test_pwl_model_dir/\")\n    config.setdefault(\"dtype\", tf.float32)\n    config.setdefault(\"input_keypoints_type\", \"fixed\")\n\n    if \"input_keypoints\" not in config:\n      # If \"input_keypoints\" are provided - other params referred by code below\n      # might be not available, so we make sure it exists before executing\n      # this code.\n      config.setdefault(\n          \"input_keypoints\",\n          np.linspace(\n              start=config[\"input_min\"],\n              stop=config[\"input_max\"],\n              num=config[\"num_keypoints\"]))\n    return config\n\n  def _TrainModel(self, config):\n    \"\"\"Trains model and returns loss.\n\n    Args:\n      config: Layer config internal for this test which specifies params of\n        piecewise linear layer to train.\n\n    
Returns:\n      Training loss.\n    \"\"\"\n    logging.info(\"Testing config:\")\n    logging.info(config)\n    config = self._SetDefaults(config)\n\n    self._ResetAllBackends()\n\n    # The input to the model can either be single or multi dimensional.\n    input_units = 1 if config[\"one_d_input\"] else config[\"units\"]\n\n    training_inputs = config[\"x_generator\"](\n        units=input_units,\n        num_points=config[\"num_training_records\"],\n        input_min=config[\"input_keypoints\"][0],\n        input_max=config[\"input_keypoints\"][-1],\n        missing_probability=config[\"missing_probability\"],\n        missing_input_value=config[\"missing_input_value\"])\n    training_labels = [config[\"y_function\"](x) for x in training_inputs]\n\n    # Either create multiple PWLCalibration layers and combine using a\n    # ParallelCombination layer, or create a single PWLCalibration with multiple\n    # output dimensions.\n    if config[\"use_multi_calibration_layer\"]:\n      num_calibration_layers = config[\"units\"]\n      pwl_calibration_units = 1\n    else:\n      num_calibration_layers = 1\n      pwl_calibration_units = config[\"units\"]\n\n    model = keras.models.Sequential()\n    model.add(keras.layers.Input(shape=[input_units], dtype=tf.float32))\n    calibration_layers = []\n    for _ in range(num_calibration_layers):\n      calibration_layers.append(\n          keras_layer.PWLCalibration(\n              units=pwl_calibration_units,\n              dtype=tf.float32,\n              input_keypoints=config[\"input_keypoints\"],\n              output_min=config[\"output_min\"],\n              output_max=config[\"output_max\"],\n              clamp_min=config[\"clamp_min\"],\n              clamp_max=config[\"clamp_max\"],\n              monotonicity=config[\"monotonicity\"],\n              convexity=config[\"convexity\"],\n              is_cyclic=config[\"is_cyclic\"],\n              kernel_initializer=config[\"initializer\"],\n              kernel_regularizer=config[\"kernel_regularizer\"],\n              impute_missing=config[\"impute_missing\"],\n              missing_output_value=config[\"missing_output_value\"],\n              missing_input_value=config[\"missing_input_value\"],\n              num_projection_iterations=config[\"num_projection_iterations\"],\n              input_keypoints_type=config[\"input_keypoints_type\"]))\n    if len(calibration_layers) == 1:\n      if config[\"use_separate_missing\"]:\n        model.add(\n            CalibrateWithSeparateMissing(\n                calibration_layer=calibration_layers[0],\n                missing_input_value=config[\"missing_input_value\"]))\n      else:\n        model.add(calibration_layers[0])\n    else:\n      model.add(parallel_combination.ParallelCombination(calibration_layers))\n\n    if config[\"units\"] > 1:\n      model.add(\n          keras.layers.Lambda(\n              lambda x: tf.reduce_mean(x, axis=1, keepdims=True)))\n\n    model.compile(\n        loss=keras.losses.mean_squared_error,\n        optimizer=config[\"optimizer\"](learning_rate=config[\"learning_rate\"]))\n\n    training_data = (training_inputs, training_labels)\n\n    loss = test_utils.run_training_loop(\n        config=config, training_data=training_data, keras_model=model\n    )\n\n    assetion_ops = []\n    for calibration_layer in calibration_layers:\n      assetion_ops.extend(\n          calibration_layer.assert_constraints(\n              eps=config[\"constraint_assertion_eps\"]))\n    if not tf.executing_eagerly() and assetion_ops:\n   
   tf.compat.v1.keras.backend.get_session().run(assetion_ops)\n\n    return loss\n\n  def _InverseAndTrain(self, config):\n    \"\"\"Changes monotonicity directions to opposite and trains model.\"\"\"\n    inversed_config = dict(config)\n    inversed_config[\"y_function\"] = lambda x: -config[\"y_function\"](x)\n\n    inversed_config[\"output_max\"] = config[\"output_min\"]\n    if inversed_config[\"output_max\"] is not None:\n      inversed_config[\"output_max\"] = inversed_config[\"output_max\"] * -1.0\n\n    inversed_config[\"output_min\"] = config[\"output_max\"]\n    if inversed_config[\"output_min\"] is not None:\n      inversed_config[\"output_min\"] = inversed_config[\"output_min\"] * -1.0\n\n    inversed_config[\"clamp_min\"] = config[\"clamp_max\"]\n    inversed_config[\"clamp_max\"] = config[\"clamp_min\"]\n    inversed_config[\"monotonicity\"] = -utils.canonicalize_monotonicity(\n        config[\"monotonicity\"])\n    inversed_config[\"convexity\"] = -utils.canonicalize_convexity(\n        config[\"convexity\"])\n    inversed_loss = self._TrainModel(inversed_config)\n    return inversed_loss\n\n  def _CreateTrainingData(self, config):\n    training_inputs = config[\"x_generator\"](\n        units=config[\"units\"],\n        num_points=config[\"num_training_records\"],\n        input_min=config[\"input_keypoints\"][0],\n        input_max=config[\"input_keypoints\"][-1],\n        missing_probability=config[\"missing_probability\"],\n        missing_input_value=config[\"missing_input_value\"])\n    training_labels = [config[\"y_function\"](x) for x in training_inputs]\n    training_inputs = tf.convert_to_tensor(training_inputs, dtype=tf.float32)\n    training_labels = tf.convert_to_tensor(training_labels, dtype=tf.float32)\n    return (training_inputs, training_labels)\n\n  def _CreateKerasLayer(self, config):\n    missing_input_value = config[\"missing_input_value\"]\n    if config[\"use_separate_missing\"]:\n      # We use 'config[\"missing_input_value\"]' to create the is_missing tensor,\n      # and we want the model to use the is_missing tensor so we don't pass\n      # a missing_input_value to the model.\n      missing_input_value = None\n    return keras_layer.PWLCalibration(\n        input_keypoints=config[\"input_keypoints\"],\n        units=config[\"units\"],\n        output_min=config[\"output_min\"],\n        output_max=config[\"output_max\"],\n        clamp_min=config[\"clamp_min\"],\n        clamp_max=config[\"clamp_max\"],\n        monotonicity=config[\"monotonicity\"],\n        convexity=config[\"convexity\"],\n        is_cyclic=config[\"is_cyclic\"],\n        kernel_initializer=config[\"initializer\"],\n        kernel_regularizer=config[\"kernel_regularizer\"],\n        impute_missing=config[\"impute_missing\"],\n        missing_output_value=config[\"missing_output_value\"],\n        missing_input_value=missing_input_value,\n        num_projection_iterations=config[\"num_projection_iterations\"],\n        dtype=config[\"dtype\"])\n\n  @parameterized.parameters(\n      (1, False, 0.001022, \"fixed\"),\n      (3, False, 0.000543, \"fixed\"),\n      (3, True, 0.000987, \"fixed\"),\n      (1, False, 0.000393, \"learned_interior\"),\n      (3, False, 0.000427, \"learned_interior\"),\n      (3, True, 0.000577, \"learned_interior\"),\n  )\n  def testUnconstrainedNoMissingValue(self, units, one_d_input, expected_loss,\n                                      input_keypoints_type):\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        
\"one_d_input\": one_d_input,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 2000,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWaves,\n        \"monotonicity\": 0,\n        \"num_keypoints\": 21,\n        \"input_min\": -1.0,\n        \"input_max\": 1.0,\n        \"output_min\": None,\n        \"output_max\": None,\n        \"input_keypoints_type\": input_keypoints_type,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1 and not one_d_input:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, None, 0.000858),\n      (1, 0.5, 0.637769),\n      (3, None, 0.000471),\n      (3, 0.5, 0.190513),\n  )\n  def testUnconstrainedWithMissingValue(self, units, missing_output_value,\n                                        expected_loss):\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 2000,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWaves,\n        \"monotonicity\": 0,\n        \"num_keypoints\": 21,\n        \"input_min\": -1.0,\n        \"input_max\": 1.0,\n        \"output_min\": None,\n        \"output_max\": None,\n        \"impute_missing\": True,\n        \"missing_input_value\": -1.2,\n        \"missing_output_value\": missing_output_value,\n        \"missing_probability\": 0.1,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    config[\"use_separate_missing\"] = True\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, -1.5, 1.5, keras.optimizers.SGD, 2100, 0.002957),\n      (1, -1.5, 1.5, keras.optimizers.Adagrad, 2100, 0.002798),\n      # TODO: Something really weird is going on here with Adam\n      # optimizer in case when num_training_epoch is exactly 2010.\n      # Test verifies result with 2100 epochs which behaves as expected.\n      (1, -1.5, 1.5, keras.optimizers.Adam, 2100, 0.000769),\n      (1, -0.5, 0.5, keras.optimizers.SGD, 200, 0.011483),\n      (1, -0.5, 0.5, keras.optimizers.Adagrad, 200, 0.011645),\n      (1, -0.5, 0.5, keras.optimizers.Adam, 200, 0.011116),\n      (3, -1.5, 1.5, keras.optimizers.Adagrad, 2100, 0.001759),\n      (3, -0.5, 0.5, keras.optimizers.Adagrad, 200, 0.005986),\n  )\n  def testNonMonotonicFunction(self, units, output_min, output_max, optimizer,\n                               num_training_epoch, expected_loss):\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 2100,\n        \"optimizer\": keras.optimizers.SGD,\n        \"learning_rate\": 0.015,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWaves,\n        \"monotonicity\": 0,\n        
\"num_keypoints\": 21,\n        \"input_min\": -1.0,\n        \"input_max\": 1.0,\n        \"output_min\": -1.5,\n        \"output_max\": 1.5,\n        \"clamp_min\": False,\n        \"clamp_max\": False,\n    }\n    config[\"output_min\"] = output_min\n    config[\"output_max\"] = output_max\n    config[\"optimizer\"] = optimizer\n    config[\"num_training_epoch\"] = num_training_epoch\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, -1.5, 0.287357),\n      (1, 1.5, 0.287357),\n      (3, -1.5, 0.122801),\n      (3, 1.5, 0.106150),\n  )\n  # Since function is symmetric result should be same for both values above.\n  def testBoundsForMissing(self, units, missing_input_value, expected_loss):\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWaves,\n        \"monotonicity\": 1,\n        \"num_keypoints\": 21,\n        \"input_min\": -1.0,\n        \"input_max\": 1.0,\n        \"output_min\": -2.0,\n        \"output_max\": 2.0,\n        \"clamp_min\": False,\n        \"clamp_max\": True,\n        \"impute_missing\": True,\n        \"missing_probability\": 0.1,\n    }\n    config[\"missing_input_value\"] = missing_input_value\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, None, None, 0.002505),\n      (1, None, 1.21, 0.008076),\n      (1, None, 1.6, 0.000251),\n      (1, None, 2.0, 0.001107),\n      (1, 0.5, None, 0.000790),\n      (1, 0.5, 1.21, 0.008353),\n      (1, 0.5, 1.6, 0.000685),\n      (1, 0.5, 2.0, 0.000694),\n      (1, 0.9, None, 0.000143),\n      (1, 0.9, 1.21, 0.008108),\n      (1, 0.9, 1.6, 0.000125),\n      (1, 0.9, 2.0, 0.000120),\n      (1, 1.2, None, 0.025762),\n      (1, 1.2, 1.21, 0.026069),\n      (1, 1.2, 1.6, 0.025240),\n      (1, 1.2, 2.0, 0.024802),\n      (3, None, None, 0.003268),\n      (3, None, 1.21, 0.003901),\n      (3, None, 1.6, 0.000897),\n      (3, None, 2.0, 0.002608),\n      (3, 0.5, None, 0.000945),\n      (3, 0.5, 1.21, 0.004830),\n      (3, 0.5, 1.6, 0.000945),\n      (3, 0.5, 2.0, 0.000923),\n      (3, 0.9, None, 0.000318),\n      (3, 0.9, 1.21, 0.004215),\n      (3, 0.9, 1.6, 0.000335),\n      (3, 0.9, 2.0, 0.000297),\n      (3, 1.2, None, 0.011354),\n      (3, 1.2, 1.21, 0.011354),\n      (3, 1.2, 1.6, 0.011354),\n      (3, 1.2, 2.0, 0.011354),\n  )\n  def testAllBoundsWithoutMonotonicityConstraints(self, units, output_min,\n                                                  output_max, expected_loss):\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        
\"y_function\": self._SmallWavesPlusOne,\n        \"monotonicity\": 0,\n        \"num_keypoints\": 21,\n        \"input_min\": 0.1,\n        \"input_max\": 0.8,\n        \"clamp_min\": False,\n        \"clamp_max\": False,\n    }\n    config[\"output_min\"] = output_min\n    config[\"output_max\"] = output_max\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, False, keras.optimizers.SGD, 0.004715),\n      (1, False, keras.optimizers.Adagrad, 0.003820),\n      (1, False, keras.optimizers.Adam, 0.002797),\n      (1, True, keras.optimizers.SGD, 0.004427),\n      (1, True, keras.optimizers.Adagrad, 0.004084),\n      # Adam is doing terrible when required to stretch monotonic function\n      # even if bounds are proper.\n      (1, True, keras.optimizers.Adam, 0.065664),\n      (3, False, keras.optimizers.Adagrad, 0.002371),\n      (3, True, keras.optimizers.Adagrad, 0.002670),\n  )\n  def testMonotonicProperBounds(self, units, is_clamped, optimizer,\n                                expected_loss):\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 400,\n        \"optimizer\": optimizer,\n        \"learning_rate\": 0.015,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWaves,\n        \"monotonicity\": \"increasing\",\n        \"num_keypoints\": 21,\n        \"input_min\": -1.0,\n        \"input_max\": 1.0,\n        \"output_min\": -1.0,\n        \"output_max\": 1.0,\n        \"clamp_min\": is_clamped,\n        \"clamp_max\": is_clamped,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, False, keras.optimizers.SGD, 0.15, 0.009563),\n      (1, False, keras.optimizers.Adagrad, 0.015, 0.011117),\n      (1, False, keras.optimizers.Adam, 0.015, 0.015356),\n      (1, True, keras.optimizers.SGD, 0.15, 0.009563),\n      (1, True, keras.optimizers.Adagrad, 0.015, 0.011117),\n      # Adam squeezes monotonic function just slightly worse than adagrad.\n      (1, True, keras.optimizers.Adam, 0.015, 0.015189),\n      (3, False, keras.optimizers.Adagrad, 0.015, 0.006057),\n      (3, True, keras.optimizers.Adagrad, 0.015, 0.006049),\n  )\n  def testMonotonicNarrowBounds(self, units, is_clamped, optimizer,\n                                learning_rate, expected_loss):\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": optimizer,\n        \"learning_rate\": learning_rate,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWaves,\n        \"monotonicity\": 1,\n        \"num_keypoints\": 21,\n        \"input_min\": -1.0,\n        \"input_max\": 1.0,\n        \"output_min\": -0.5,\n        \"output_max\": 0.5,\n        \"clamp_min\": is_clamped,\n        \"clamp_max\": is_clamped,\n    }\n    loss = self._TrainModel(config)\n    
self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, False, keras.optimizers.SGD, 0.005920),\n      (1, False, keras.optimizers.Adagrad, 0.006080),\n      (1, False, keras.optimizers.Adam, 0.002914),\n      (1, True, keras.optimizers.SGD, 0.013836),\n      (1, True, keras.optimizers.Adagrad, 0.066928),\n      # Adam is doing terrible when required to stretch monotonic function.\n      (1, True, keras.optimizers.Adam, 0.230402),\n      (3, False, keras.optimizers.Adagrad, 0.004891),\n      (3, True, keras.optimizers.Adagrad, 0.021490),\n  )\n  def testMonotonicWideBounds(self, units, is_clamped, optimizer,\n                              expected_loss):\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 400,\n        \"optimizer\": optimizer,\n        \"learning_rate\": 0.015,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWaves,\n        \"monotonicity\": 1,\n        \"num_keypoints\": 21,\n        \"input_min\": -1.0,\n        \"input_max\": 1.0,\n        \"output_min\": -1.5,\n        \"output_max\": 1.5,\n        \"clamp_min\": is_clamped,\n        \"clamp_max\": is_clamped,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, None, None, False, False, 0.003744),\n      (1, None, None, False, True, 0.003744),\n      (1, None, 1.6, True, False, 0.001456),\n      (1, None, 1.6, True, True, 0.001465),\n      (1, None, 2.0, False, False, 0.001712),\n      (1, None, 2.0, False, True, 0.01623),\n      (1, None, 2.0, True, False, 0.001712),\n      (1, None, 2.0, True, True, 0.01623),\n      (1, 0.5, None, False, False, 0.002031),\n      (1, 0.5, None, False, True, 0.002031),\n      (1, 0.5, None, True, False, 0.003621),\n      (1, 0.5, None, True, True, 0.003621),\n      (1, None, None, True, False, 0.003744),\n      (1, 0.5, 1.21, False, False, 0.007572),\n      (1, 0.5, 1.21, False, True, 0.007572),\n      (1, 0.5, 1.21, True, False, 0.009876),\n      (1, 0.5, 1.21, True, True, 0.009876),\n      (1, 0.5, 1.6, False, False, 0.001916),\n      (1, 0.5, 1.6, False, True, 0.001737),\n      (1, 0.5, 1.6, True, False, 0.003103),\n      (1, 0.5, 1.6, True, True, 0.002692),\n      (1, 0.5, 2.0, False, False, 0.001873),\n      (1, 0.5, 2.0, False, True, 0.003333),\n      (1, None, None, True, True, 0.003744),\n      (1, 0.5, 2.0, True, False, 0.003315),\n      (1, 0.5, 2.0, True, True, 0.004289),\n      (1, 0.9, None, False, False, 0.00151),\n      (1, 0.9, None, False, True, 0.00151),\n      (1, 0.9, None, True, False, 0.001552),\n      (1, 0.9, None, True, True, 0.001552),\n      (1, 0.9, 1.21, False, False, 0.005387),\n      (1, 0.9, 1.21, False, True, 0.005387),\n      (1, 0.9, 1.21, True, False, 0.005427),\n      (1, 0.9, 1.21, True, True, 0.005427),\n      (1, None, 1.21, False, False, 0.005366),\n      (1, 0.9, 1.6, False, False, 0.0015),\n      (1, 0.9, 1.6, False, True, 0.001454),\n      (1, 0.9, 1.6, True, False, 0.001546),\n    
  (1, 0.9, 1.6, True, True, 0.001514),\n      (1, 0.9, 2.0, False, False, 0.001501),\n      (1, 0.9, 2.0, False, True, 0.003067),\n      (1, 0.9, 2.0, True, False, 0.001547),\n      (1, 0.9, 2.0, True, True, 0.00312),\n      (1, 1.2, None, False, False, 0.021835),\n      (1, 1.2, None, False, True, 0.021835),\n      (1, None, 1.21, False, True, 0.005366),\n      (1, 1.2, None, True, False, 0.021835),\n      (1, 1.2, None, True, True, 0.021835),\n      (1, 1.2, 1.21, False, False, 0.025733),\n      (1, 1.2, 1.21, False, True, 0.025733),\n      (1, 1.2, 1.21, True, False, 0.025733),\n      (1, 1.2, 1.21, True, True, 0.025733),\n      (1, 1.2, 1.6, False, False, 0.021834),\n      (1, 1.2, 1.6, False, True, 0.021967),\n      (1, 1.2, 1.6, True, False, 0.021834),\n      (1, 1.2, 1.6, True, True, 0.021967),\n      (1, None, 1.21, True, False, 0.005366),\n      (1, 1.2, 2.0, False, False, 0.021834),\n      (1, 1.2, 2.0, False, True, 0.023642),\n      (1, 1.2, 2.0, True, False, 0.021834),\n      (1, 1.2, 2.0, True, True, 0.023642),\n      (1, None, 1.21, True, True, 0.005366),\n      (1, None, 1.6, False, False, 0.001456),\n      (1, None, 1.6, False, True, 0.001465),\n      (3, None, None, False, False, 0.003969),\n      (3, None, None, False, True, 0.003969),\n      (3, 0.5, None, True, False, 0.003125),\n      (3, 0.5, None, True, True, 0.003125),\n      (3, None, None, True, False, 0.003969),\n      (3, 0.5, 1.21, False, False, 0.003676),\n      (3, 0.5, 1.21, False, True, 0.003676),\n      (3, 0.5, 1.21, True, False, 0.006550),\n      (3, 0.5, 1.21, True, True, 0.006550),\n      (3, 0.5, 1.6, False, False, 0.001246),\n      (3, 0.5, 1.6, False, True, 0.001000),\n      (3, 0.5, 1.6, True, False, 0.002775),\n      (3, None, 1.6, True, False, 0.000662),\n      (3, 0.5, 1.6, True, True, 0.002720),\n      (3, 0.5, 2.0, False, False, 0.001272),\n      (3, 0.5, 2.0, False, True, 0.001779),\n      (3, None, None, True, True, 0.003969),\n      (3, 0.5, 2.0, True, False, 0.002852),\n      (3, 0.5, 2.0, True, True, 0.003496),\n      (3, 0.9, None, False, False, 0.000597),\n      (3, 0.9, None, False, True, 0.000597),\n      (3, 0.9, None, True, False, 0.000678),\n      (3, 0.9, None, True, True, 0.000678),\n      (3, None, 1.6, True, True, 0.000640),\n      (3, 0.9, 1.21, False, False, 0.002630),\n      (3, 0.9, 1.21, False, True, 0.002630),\n      (3, 0.9, 1.21, True, False, 0.002906),\n      (3, 0.9, 1.21, True, True, 0.002906),\n      (3, None, 1.21, False, False, 0.002565),\n      (3, 0.9, 1.6, False, False, 0.000575),\n      (3, 0.9, 1.6, False, True, 0.000520),\n      (3, 0.9, 1.6, True, False, 0.000648),\n      (3, 0.9, 1.6, True, True, 0.000606),\n      (3, 0.9, 2.0, False, False, 0.000556),\n      (3, None, 2.0, False, False, 0.000901),\n      (3, 0.9, 2.0, False, True, 0.001230),\n      (3, 0.9, 2.0, True, False, 0.000636),\n      (3, 0.9, 2.0, True, True, 0.001314),\n      (3, 1.2, None, False, False, 0.010638),\n      (3, 1.2, None, False, True, 0.010638),\n      (3, None, 1.21, False, True, 0.002565),\n      (3, 1.2, None, True, False, 0.010638),\n      (3, 1.2, None, True, True, 0.010638),\n      (3, 1.2, 1.21, False, False, 0.011300),\n      (3, 1.2, 1.21, False, True, 0.011309),\n      (3, None, 2.0, False, True, 0.003166),\n      (3, 1.2, 1.21, True, False, 0.011300),\n      (3, 1.2, 1.21, True, True, 0.011309),\n      (3, 1.2, 1.6, False, False, 0.010631),\n      (3, 1.2, 1.6, False, True, 0.012681),\n      (3, 1.2, 1.6, True, False, 0.010631),\n      (3, 1.2, 1.6, True, True, 
0.012681),\n      (3, None, 1.21, True, False, 0.002565),\n      (3, 1.2, 2.0, False, False, 0.010627),\n      (3, 1.2, 2.0, False, True, 0.016435),\n      (3, 1.2, 2.0, True, False, 0.010627),\n      (3, None, 2.0, True, False, 0.000901),\n      (3, 1.2, 2.0, True, True, 0.016435),\n      (3, None, 1.21, True, True, 0.002565),\n      (3, None, 1.6, False, False, 0.000662),\n      (3, None, 1.6, False, True, 0.000640),\n      (3, None, 2.0, True, True, 0.003166),\n      (3, 0.5, None, False, False, 0.001334),\n      (3, 0.5, None, False, True, 0.001334),\n  )\n  def testAllBoundsAndMonotonicityDirection(self, units, output_min, output_max,\n                                            clamp_min, clamp_max,\n                                            expected_loss):\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWavesPlusOne,\n        \"monotonicity\": 1,\n        \"num_keypoints\": 21,\n        \"input_min\": 0.1,\n        \"input_max\": 0.8,\n        \"output_min\": output_min,\n        \"output_max\": output_max,\n        \"clamp_min\": clamp_min,\n        \"clamp_max\": clamp_max,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    self.assertAlmostEqual(\n        loss, self._InverseAndTrain(config), delta=self._small_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n      self.assertAlmostEqual(\n          loss, self._InverseAndTrain(config), delta=self._small_eps)\n\n  @parameterized.parameters(\n      (1, 1, 0.018919),\n      (1, -1, 0.019434),\n      (3, \"convex\", 0.008592),\n      (3, \"concave\", 0.01134),\n  )\n  def testConvexitySimple(self, units, convexity, expected_loss):\n    # No constraints other than convexity.\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 120,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWaves,\n        \"monotonicity\": \"none\",\n        \"convexity\": convexity,\n        \"num_keypoints\": 21,\n        \"input_min\": -1.0,\n        \"input_max\": 1.0,\n        \"output_min\": None,\n        \"output_max\": None,\n        \"num_projection_iterations\": 18,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, 1, 0.006286),\n      (1, -1, 0.078076),\n      (3, 1, 0.002941),\n      (3, -1, 0.032497),\n  )\n  def testConvexityNonUniformKeypoints(self, units, convexity, expected_loss):\n    # No constraints other than convexity.\n    if self._disable_all:\n      return\n\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.Adagrad,\n        
\"learning_rate\": 1.0,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._WavyParabola,\n        \"monotonicity\": 0,\n        \"convexity\": convexity,\n        \"input_keypoints\": [-1.0, -0.9, -0.3, -0.2, 0.0, 0.3, 0.31, 0.35, 1.0],\n        \"output_min\": None,\n        \"output_max\": None,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, 2, 0.033706),\n      (1, 3, 0.006485),\n      (1, 4, 0.005128),\n      (1, 5, 0.004878),\n      (1, 6, 0.005083),\n      (1, 7, 0.004860),\n      (3, 2, 0.013585),\n      (3, 3, 0.003311),\n      (3, 4, 0.002633),\n      (3, 5, 0.001909),\n      (3, 6, 0.001822),\n      (3, 7, 0.001599),\n  )\n  def testConvexityDifferentNumKeypoints(self, units, num_keypoints,\n                                         expected_loss):\n    # No constraints other than convexity.\n    if self._disable_all:\n      return\n\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 120,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.3,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._WavyParabola,\n        \"monotonicity\": 0,\n        \"convexity\": 1,\n        \"num_keypoints\": num_keypoints,\n        \"input_min\": -0.8,\n        \"input_max\": 0.8,\n        \"output_min\": None,\n        \"output_max\": None,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, \"increasing\", None, 0.055837),\n      (1, \"decreasing\", None, 0.046657),\n      (1, \"none\", 0.0, 0.027777),\n      (1, \"increasing\", 0.0, 0.065516),\n      (1, \"decreasing\", 0.0, 0.057453),\n      (3, \"increasing\", None, 0.022467),\n      (3, \"decreasing\", None, 0.019012),\n      (3, \"none\", 0.0, 0.014693),\n      (3, \"increasing\", 0.0, 0.026284),\n      (3, \"decreasing\", 0.0, 0.025498),\n  )\n  def testConvexityWithMonotonicityAndBounds(self, units, monotonicity,\n                                             output_max, expected_loss):\n    if self._disable_all:\n      return\n\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 120,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.5,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._WavyParabola,\n        \"monotonicity\": monotonicity,\n        \"convexity\": 1,\n        \"num_keypoints\": 21,\n        \"input_min\": -1.0,\n        \"input_max\": 1.0,\n        \"output_min\": None,\n        \"output_max\": output_max,\n        \"num_projection_iterations\": 8,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    self.assertAlmostEqual(\n        loss, self._InverseAndTrain(config), delta=self._small_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      
self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n      self.assertAlmostEqual(\n          loss, self._InverseAndTrain(config), delta=self._small_eps)\n\n  @parameterized.parameters(\n      ([-1.0, -0.8, 0.0, 0.2, 0.8, 1.0],),\n      (np.array([-1.0, -0.8, 0.0, 0.2, 0.8, 1.0]),),\n  )\n  def testInputKeypoints(self, keypoints):\n    if self._disable_all:\n      return\n    config = {\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 200,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWaves,\n        \"monotonicity\": 0,\n        \"input_keypoints\": keypoints,\n        \"output_min\": None,\n        \"output_max\": None,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.009650, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, None, 600, 0.002058),\n      (1, (\"laplacian\", 0.01, 0.0), 420, 0.040492),\n      (1, (\"hessian\", 0.01, 0.01), 300, 0.040932),\n      (1, (\"wrinkle\", 0.01, 0.01), 300, 0.027430),\n      (3, None, 600, 0.002150),\n      (3, (\"laplacian\", 0.01, 0.0), 420, 0.096667),\n      (3, (\"hessian\", 0.01, 0.01), 300, 0.092306),\n      (3, (\"wrinkle\", 0.01, 0.01), 300, 0.064053),\n  )\n  def testIsCyclic(self, units, regularizer, num_training_epoch, expected_loss):\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": num_training_epoch,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformlyIncludeBounds,\n        \"y_function\": self._SinCycle,\n        \"monotonicity\": 0,\n        \"input_min\": 0.0,\n        \"input_max\": 24.0,\n        \"num_keypoints\": 10,\n        \"is_cyclic\": True,\n        \"kernel_regularizer\": regularizer,\n        \"output_min\": None,\n        \"output_max\": None,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  @parameterized.parameters(\n      (1, \"equal_heights\", 0.332572),\n      (1, \"equal_slopes\", 0.476452),\n      (3, \"equal_heights\", 0.271896),\n      (3, \"equal_slopes\", 0.356754),\n  )\n  def testInitializer(self, units, initializer, expected_loss):\n    if self._disable_all:\n      return\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        # 0 training epochs to see pure output of initializer.\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWaves,\n        \"monotonicity\": 0,\n        \"input_keypoints\": [-1.0, -0.8, 0.0, 0.2, 0.8, 1.0],\n        \"output_min\": -1.0,\n        \"output_max\": 2.0,\n        \"initializer\": initializer,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, expected_loss, delta=self._loss_eps)\n\n  # TODO: this test is only using the first piece of 
the PWL.\n  @parameterized.parameters(\n      (1, (\"laplacian\", 0.01, 0.001), 0.091, 0.089631),\n      (1, (\"Hessian\", 0.01, 0.001), 0.035, 0.033504),\n      (1, (\"wrinkle\", 0.01, 0.001), 0.011, 0.007018),\n      # Standard Keras regularizer:\n      (1, keras.regularizers.l1_l2(l1=0.01, l2=0.001), 0.091, 0.089906),\n      # List of regularizers:\n      (1, [(\"Hessian\", 0.01, 0.001),\n           keras.regularizers.l1_l2(l1=0.01, l2=0.001)], 0.126, 0.122192),\n      (3, (\"laplacian\", 0.01, 0.001), 0.273, 0.263244),\n      (3, (\"Hessian\", 0.01, 0.001), 0.105, 0.097368),\n      (3, (\"wrinkle\", 0.01, 0.001), 0.033, 0.013650),\n      # Standard Keras regularizer:\n      (3, keras.regularizers.l1_l2(l1=0.01, l2=0.001), 0.273, 0.265924),\n      # List of regularizers:\n      (3, [(\"Hessian\", 0.01, 0.001),\n           keras.regularizers.l1_l2(l1=0.01, l2=0.001)], 0.378, 0.354917),\n  )\n  def testRegularizers(self, units, regularizer, pure_reg_loss, training_loss):\n    if self._disable_all:\n      return\n    keypoints = [0.0, 1.0, 2.0, 3.0]\n    pwl_weights = [0.0, 1.0, 2.0, 4.0]\n    multi_pwl_weights = [[w] * units for w in pwl_weights]\n    # Keypoint outputs which correspond to weights: [0.0, 1.0, 3.0, 7.0]\n    config = {\n        \"units\": units,\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"input_keypoints\": keypoints,\n        \"y_function\": self._GenPWLFunction(keypoints, multi_pwl_weights),\n        # Initializer exactly matches target function.\n        \"initializer\":\n            lambda shape, dtype: tf.constant(multi_pwl_weights, shape=shape),\n        \"kernel_regularizer\": regularizer,\n    }  # pyformat: disable\n    loss = self._TrainModel(config)\n    # This loss is pure regularization loss because initializer matches target\n    # function and there was 0 training epochs.\n    self.assertAlmostEqual(loss, pure_reg_loss, delta=self._loss_eps)\n\n    config[\"num_training_epoch\"] = 20\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, training_loss, delta=self._loss_eps)\n    if units > 1:\n      config[\"use_multi_calibration_layer\"] = True\n      config[\"initializer\"] = (\n          lambda shape, dtype: tf.constant(pwl_weights, shape=shape))\n      loss = self._TrainModel(config)\n      self.assertAlmostEqual(loss, training_loss, delta=self._loss_eps)\n\n  def testAssertMonotonicity(self):\n    if self._disable_all:\n      return\n    decreasing_initializer = keras_layer.UniformOutputInitializer(\n        output_min=0.0, output_max=1.0, monotonicity=-1)\n    # Specify decreasing initializer and do 0 training iterations so no\n    # projections are being executed.\n    config = {\n        \"num_training_records\": 100,\n        \"num_training_epoch\": 0,\n        \"optimizer\": keras.optimizers.Adagrad,\n        \"learning_rate\": 0.15,\n        \"x_generator\": self._ScatterXUniformly,\n        \"y_function\": self._SmallWaves,\n        \"monotonicity\": 0,\n        \"num_keypoints\": 21,\n        \"input_min\": 0.0,\n        \"input_max\": 1.0,\n        \"output_min\": 0.0,\n        \"output_max\": 1.0,\n        \"initializer\": decreasing_initializer,\n    }\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.347888, delta=self._loss_eps)\n\n    # We have decreasing initializer so with 0 trainig steps monotonicity is\n    # violated.\n   
 with self.assertRaises(tf.errors.InvalidArgumentError):\n      config[\"monotonicity\"] = 1\n      loss = self._TrainModel(config)\n\n    # Now set upper bound bigger than necessary. Everything should be fine...\n    config[\"monotonicity\"] = 0\n    config[\"output_max\"] = 1.5\n    loss = self._TrainModel(config)\n    self.assertAlmostEqual(loss, 0.347888, delta=self._loss_eps)\n\n    # ... until we require to clamp max.\n    with self.assertRaises(tf.errors.InvalidArgumentError):\n      config[\"clamp_max\"] = True\n      loss = self._TrainModel(config)\n\n  def testOutputShape(self):\n    if self._disable_all:\n      return\n\n    # Not Splitting\n    units = 10\n    input_keypoints = [1, 2, 3, 4, 5]\n    input_shape, output_shape = (units,), (None, units)\n    input_a = keras.layers.Input(shape=input_shape)\n    pwl_0 = keras_layer.PWLCalibration(\n        input_keypoints=input_keypoints, units=units)\n    output = pwl_0(input_a)\n    self.assertAllEqual(output_shape, pwl_0.compute_output_shape(input_a.shape))\n    self.assertAllEqual(output_shape, output.shape)\n\n    # Splitting\n    output_shape = [(None, 1)] * units\n    pwl_1 = keras_layer.PWLCalibration(\n        input_keypoints=input_keypoints, units=units, split_outputs=True)\n    output = pwl_1(input_a)\n    self.assertAllEqual(output_shape, pwl_1.compute_output_shape(input_a.shape))\n    self.assertAllEqual(output_shape, [o.shape for o in output])\n\n  @parameterized.parameters((\"fixed\", 1, 1), (\"fixed\", 1, 2), (\"fixed\", 2, 2),\n                            (\"learned_interior\", 1, 1),\n                            (\"learned_interior\", 1, 2),\n                            (\"learned_interior\", 2, 2))\n  def testKeypointsInputs(self, input_keypoints_type, input_dims, output_units):\n    if self._disable_all:\n      return\n\n    input_keypoints = [0, 0.5, 1]\n    expected_function_output = np.array([[0.0] * output_units,\n                                         [0.5] * output_units,\n                                         [1.0] * output_units])\n\n    # Check after layer build\n    pwl = keras_layer.PWLCalibration(\n        input_keypoints=input_keypoints,\n        units=output_units,\n        input_keypoints_type=input_keypoints_type)\n    pwl.build(input_shape=[10, input_dims])\n    self.assertAllEqual(expected_function_output, pwl.keypoints_inputs())\n\n    # Check after Keras model compile\n    model = keras.models.Sequential()\n    model.add(keras.layers.Input(shape=[input_dims], dtype=tf.float32))\n    model.add(pwl)\n    model.compile(loss=keras.losses.mean_squared_error)\n    self.assertAllEqual(expected_function_output, pwl.keypoints_inputs())\n\n    # Check after Keras model fit; look for change in learned case.\n    train_x = np.random.uniform(size=(10, input_dims))\n    train_y = train_x[:, 0]**2\n    model.fit(train_x, train_y, batch_size=len(train_x), epochs=5, verbose=0)\n    if input_keypoints_type == \"fixed\":\n      self.assertAllEqual(expected_function_output, pwl.keypoints_inputs())\n    else:\n      self.assertNotAllEqual(expected_function_output, pwl.keypoints_inputs())\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
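  {
    "path": "examples/pwl_calibration_sketch.py",
    "content": "# Illustrative sketch (hypothetical example file, not part of the library):\n# shows how the tfl.layers.PWLCalibration layer exercised by the tests above\n# can be used in a small Keras model. Keypoints, bounds, monotonicity and the\n# Adagrad learning rate mirror the test configs; the toy data is made up.\nimport numpy as np\nimport tensorflow as tf\nimport tensorflow_lattice as tfl\n\n# Calibrator with 21 fixed keypoints on [-1, 1], bounded and monotonically\n# increasing output, as in the monotonic test cases above.\ncalibrator = tfl.layers.PWLCalibration(\n    input_keypoints=np.linspace(-1.0, 1.0, num=21),\n    units=1,\n    output_min=-1.0,\n    output_max=1.0,\n    monotonicity='increasing')\n\nmodel = tf.keras.models.Sequential()\nmodel.add(tf.keras.layers.Input(shape=(1,), dtype=tf.float32))\nmodel.add(calibrator)\nmodel.compile(\n    loss=tf.keras.losses.mean_squared_error,\n    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.15))\n\n# Fit to a toy monotonic target and calibrate a few fresh inputs.\nx = np.random.uniform(-1.0, 1.0, size=(100, 1)).astype(np.float32)\ny = np.sin(x) / 2.0\nmodel.fit(x, y, batch_size=100, epochs=10, verbose=0)\nprint(model.predict(np.array([[-0.5], [0.0], [0.5]], dtype=np.float32)))\n"
  },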
  {
    "path": "tensorflow_lattice/python/rtl_layer.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Layer which represents an ensemble of Random Tiny Lattices (RTL).\n\nSee class level comment.\n\nThis layer can take multiple inputs and use them in an ensemble of lattices.\nThe output can be set to be monotonic with respect to a subset of features. This\nlayer can output either a single dense tensor, or can have separate monotonic\nand unconstrained outputs to be fed into another RTL layer.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport itertools\n\nfrom absl import logging\nimport numpy as np\nimport six\nimport tensorflow as tf\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, 'version', None)\nif version_fn and version_fn().startswith('3.'):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\nfrom . import kronecker_factored_lattice_layer as kfll\nfrom . import lattice_layer\nfrom . import rtl_lib\n\n_MAX_RTL_SWAPS = 10000\n_RTLInput = collections.namedtuple('_RTLInput',\n                                   ['monotonicity', 'group', 'input_index'])\nRTL_KFL_NAME = 'rtl_kronecker_factored_lattice'\nRTL_LATTICE_NAME = 'rtl_lattice'\nINPUTS_FOR_UNITS_PREFIX = 'inputs_for_lattice'\nRTL_CONCAT_NAME = 'rtl_concat'\n\n\nclass RTL(keras.layers.Layer):\n  # pyformat: disable\n  \"\"\"Layer which includes a random ensemble of lattices.\n\n  RTL (Random Tiny Lattices) is an ensemble of `tfl.layers.Lattice` layers that\n  takes in a collection of monotonic and unconstrained features and randomly\n  arranges them into lattices of a given rank. The input is taken as \"groups\",\n  and inputs from the same group will not be used in the same lattice. E.g. the\n  input can be the output of a calibration layer with multiple units applied to\n  the same input feature. If there are more slots in the RTL than the number of\n  inputs, inputs will be repeatedly used. Repeats will be approximately uniform\n  across all inputs.\n\n  Input shape:\n  One of:\n    - A dict with keys in `['unconstrained', 'increasing']`, and the values\n      either a list of tensors of shape (batch_size, D_i), or a single tensor\n      of shape (batch_size, D) that will be conceptually split into a list of D\n      tensors of size (batch_size, 1). 
Each tensor in the list is considered a\n      \"group\" of features that the RTL layer should try not to use in the same\n      lattice.\n    - A single tensor of shape (batch_size, D), which is considered to be\n      unconstrained and will be conceptually split into a list of D tensors of\n      size (batch_size, 1).\n\n  Output shape:\n  If `separate_outputs == True`, the output will be in the same format as the\n  input and can be passed to follow on RTL layers:\n  `{'unconstrained': unconstrained_out, 'increasing': mon_out}` where\n  `unconstrained_out` and `mon_out` are of (batch_size, num_unconstrained_out)\n  and (batch_size, num_mon_out) respectively, and\n  `num_unconstrained_out + num_mon_out == num_lattices`. If\n  `separate_outputs == False` the output will be a rank-2 tensor with shape:\n  (batch_size, num_lattices) if average_outputs is False, or (batch_size, 1) if\n  True.\n\n  Attributes:\n    - All `__init__ `arguments.\n\n  Example:\n\n  ```python\n  a = keras.Input(shape=(1,))\n  b = keras.Input(shape=(1,))\n  c = keras.Input(shape=(1,))\n  d = keras.Input(shape=(1,))\n  cal_a = tfl.layers.CategoricalCalibration(\n      units=10, output_min=0, output_max=1, ...)(a)\n  cal_b = tfl.layers.PWLCalibration(\n      units=20, output_min=0, output_max=1, ...)(b)\n  cal_c = tfl.layers.PWLCalibration(\n      units=10, output_min=0, output_max=1, monotonicity='increasing', ...)(c)\n  cal_d = tfl.layers.PWLCalibration(\n      units=20, output_min=0, output_max=1, monotonicity='decreasing', ...)(d)\n  rtl_0 = RTL(\n      num_lattices=20,\n      lattice_rank=3,\n      output_min=0,\n      output_max=1,\n      separate_outputs=True,\n  )({\n      'unconstrained': [cal_a, cal_b],\n      'increasing': [cal_c, cal_d],\n  })\n  rtl_1 = RTL(num_lattices=5, lattice_rank=4)(rtl_0)\n  outputs = tfl.layers.Linear(\n      num_input_dims=5,\n      monotonicities=['increasing'] * 5,\n  )(rtl_1)\n  model = keras.Model(inputs=[a, b, c, d], outputs=outputs)\n  ```\n  \"\"\"\n  # pyformat: enable\n\n  def __init__(self,\n               num_lattices,\n               lattice_rank,\n               lattice_size=2,\n               output_min=None,\n               output_max=None,\n               init_min=None,\n               init_max=None,\n               separate_outputs=False,\n               random_seed=42,\n               num_projection_iterations=10,\n               monotonic_at_every_step=True,\n               clip_inputs=True,\n               interpolation='hypercube',\n               parameterization='all_vertices',\n               num_terms=2,\n               avoid_intragroup_interaction=True,\n               kernel_initializer='random_monotonic_initializer',\n               kernel_regularizer=None,\n               average_outputs=False,\n               **kwargs):\n    # pyformat: disable\n    \"\"\"Initializes an instance of `RTL`.\n\n    Args:\n      num_lattices: Number of lattices in the ensemble.\n      lattice_rank: Number of features used in each lattice.\n      lattice_size: Number of lattice vertices per dimension (minimum is 2).\n      output_min: None or lower bound of the output.\n      output_max: None or upper bound of the output.\n      init_min: None or lower bound of lattice kernel initialization.\n      init_max: None or upper bound of lattice kernel initialization.\n      separate_outputs: If set to true, the output will be a dict in the same\n        format as the input to the layer, ready to be passed to another RTL\n        layer. 
If false, the output will be a single tensor of shape\n        (batch_size, num_lattices). See output shape for details.\n      random_seed: Random seed for the randomized feature arrangement in the\n        ensemble. Also used for initialization of lattices using\n        `'kronecker_factored'` parameterization.\n      num_projection_iterations: Number of iterations of Dykstra projections\n        algorithm. Projection updates will be closer to a true projection (with\n        respect to the L2 norm) with higher number of iterations. Increasing\n        this number has diminishing return on projection precsion. Infinite\n        number of iterations would yield perfect projection. Increasing this\n        number might slightly improve convergence by cost of slightly increasing\n        running time. Most likely you want this number to be proportional to\n        number of lattice vertices in largest constrained dimension.\n      monotonic_at_every_step: Whether to strictly enforce monotonicity and\n        trust constraints after every gradient update by applying a final\n        imprecise projection. Setting this parameter to True together with small\n        num_projection_iterations parameter is likely to hurt convergence.\n      clip_inputs: If inputs should be clipped to the input range of the\n        lattice.\n      interpolation: One of 'hypercube' or 'simplex' interpolation. For a\n        d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas\n        'simplex' uses d+1 parameters and thus scales better. For details see\n        `tfl.lattice_lib.evaluate_with_simplex_interpolation` and\n        `tfl.lattice_lib.evaluate_with_hypercube_interpolation`.\n      parameterization: The parameterization of the lattice function class to\n        use. A lattice function is uniquely determined by specifying its value\n        on every lattice vertex. A parameterization scheme is a mapping from a\n        vector of parameters to a multidimensional array of lattice vertex\n        values. It can be one of:\n          - String `'all_vertices'`: This is the \"traditional\" parameterization\n            that keeps one scalar parameter per lattice vertex where the mapping\n            is essentially the identity map. With this scheme, the number of\n            parameters scales exponentially with the number of inputs to the\n            lattice. The underlying lattices used will be `tfl.layers.Lattice`\n            layers.\n          - String `'kronecker_factored'`: With this parameterization, for each\n            lattice input i we keep a collection of `num_terms` vectors each\n            having `feature_configs[0].lattice_size` entries (note that the\n            lattice size of the first feature will be used as the lattice size\n            for all other features as well). To obtain the tensor of lattice\n            vertex values, for `t=1,2,...,num_terms` we compute the outer\n            product of the `t'th` vector in each collection, multiply by a\n            per-term scale, and sum the resulting tensors. Finally, we add a\n            single shared bias parameter to each entry in the sum. With this\n            scheme, the number of parameters grows linearly with `lattice_rank`\n            (assuming lattice sizes and `num_terms` are held constant).\n            Currently, only monotonicity shape constraint and bound constraint\n            are supported for this scheme. Regularization is not currently\n            supported. 
The underlying lattices used will be\n            `tfl.layers.KroneckerFactoredLattice` layers.\n      num_terms: The number of terms in a lattice using `'kronecker_factored'`\n        parameterization. Ignored if parameterization is set to\n        `'all_vertices'`.\n      avoid_intragroup_interaction: If set to true, the RTL algorithm will try\n        to avoid having inputs from the same group in the same lattice.\n      kernel_initializer: One of:\n        - `'linear_initializer'`: initialize parameters to form a linear\n          function with positive and equal coefficients for monotonic dimensions\n          and 0.0 coefficients for other dimensions. Linear function is such\n          that minimum possible output is equal to output_min and maximum\n          possible output is equal to output_max. See\n          `tfl.lattice_layer.LinearInitializer` class docstring for more\n          details. This initialization is not supported when using the\n          `'kronecker_factored'` parameterization.\n        - `'random_monotonic_initializer'`: initialize parameters uniformly at\n          random such that all parameters are monotonically increasing for each\n          input. Parameters will be sampled uniformly at random from the range\n          `[init_min, init_max]` if specified, otherwise\n          `[output_min, output_max]`. See\n          `tfl.lattice_layer.RandomMonotonicInitializer` class docstring for\n          more details. This initialization is not supported when using the\n          `'kronecker_factored'` parameterization.\n        - `'kfl_random_monotonic_initializer'`: initialize parameters uniformly\n          at random such that all parameters are monotonically increasing for\n          each monotonic input. Parameters will be sampled uniformly at random\n          from the range `[init_min, init_max]` if specified. Otherwise, the\n          initialization range will be algorithmically determined depending on\n          output_{min/max}. See `tfl.layers.KroneckerFactoredLattice` and\n          `tfl.kronecker_factored_lattice.KFLRandomMonotonicInitializer` class\n          docstrings for more details. This initialization is not supported when\n          using `'all_vertices'` parameterization.\n      kernel_regularizer: None or a single element or a list of following:\n        - Tuple `('torsion', l1, l2)` or List `['torsion', l1, l2]` where l1 and\n          l2 represent corresponding regularization amount for graph Torsion\n          regularizer. l1 and l2 must be single floats. Lists of floats to\n          specify different regularization amount for every dimension is not\n          currently supported.\n        - Tuple `('laplacian', l1, l2)` or List `['laplacian', l1, l2]` where l1\n          and l2 represent corresponding regularization amount for graph\n          Laplacian regularizer. l1 and l2 must be single floats. Lists of\n          floats to specify different regularization amount for every dimension\n          is not currently supported.\n      average_outputs: Whether to average the outputs of this layer. 
Ignored\n        when separate_outputs is True.\n      **kwargs: Other args passed to `keras.layers.Layer` initializer.\n\n    Raises:\n      ValueError: If layer hyperparameters are invalid.\n      ValueError: If `parameterization` is not one of `'all_vertices'` or\n        `'kronecker_factored'`.\n    \"\"\"\n    # pyformat: enable\n    rtl_lib.verify_hyperparameters(\n        lattice_size=lattice_size,\n        output_min=output_min,\n        output_max=output_max,\n        interpolation=interpolation,\n        parameterization=parameterization,\n        kernel_initializer=kernel_initializer,\n        kernel_regularizer=kernel_regularizer)\n    super(RTL, self).__init__(**kwargs)\n    self.num_lattices = num_lattices\n    self.lattice_rank = lattice_rank\n    self.lattice_size = lattice_size\n    self.output_min = output_min\n    self.output_max = output_max\n    self.init_min = init_min\n    self.init_max = init_max\n    self.separate_outputs = separate_outputs\n    self.random_seed = random_seed\n    self.num_projection_iterations = num_projection_iterations\n    self.monotonic_at_every_step = monotonic_at_every_step\n    self.clip_inputs = clip_inputs\n    self.interpolation = interpolation\n    self.parameterization = parameterization\n    self.num_terms = num_terms\n    self.avoid_intragroup_interaction = avoid_intragroup_interaction\n    self.kernel_initializer = kernel_initializer\n    self.kernel_regularizer = kernel_regularizer\n    self.average_outputs = average_outputs\n\n  def build(self, input_shape):\n    \"\"\"Standard Keras build() method.\"\"\"\n    rtl_lib.verify_hyperparameters(\n        lattice_size=self.lattice_size, input_shape=input_shape)\n    # Convert kernel regularizers to proper form (tuples).\n    kernel_regularizer = self.kernel_regularizer\n    if isinstance(self.kernel_regularizer, list):\n      if isinstance(self.kernel_regularizer[0], six.string_types):\n        kernel_regularizer = tuple(self.kernel_regularizer)\n      else:\n        kernel_regularizer = [tuple(r) for r in self.kernel_regularizer]\n    self._rtl_structure = self._get_rtl_structure(input_shape)\n    # dict from monotonicities to the lattice layers with those monotonicities.\n    self._lattice_layers = {}\n    for monotonicities, inputs_for_units in self._rtl_structure:\n      monotonicities_str = ''.join(\n          [str(monotonicity) for monotonicity in monotonicities])\n      # Passthrough names for reconstructing model graph.\n      inputs_for_units_name = '{}_{}'.format(INPUTS_FOR_UNITS_PREFIX,\n                                             monotonicities_str)\n      # Use control dependencies to save inputs_for_units as graph constant for\n      # visualisation toolbox to be able to recover it from saved graph.\n      # Wrap this constant into pure op since in TF 2.0 there are issues passing\n      # tensors into control_dependencies.\n      with tf.control_dependencies([\n          tf.constant(\n              inputs_for_units, dtype=tf.int32, name=inputs_for_units_name)\n      ]):\n        units = len(inputs_for_units)\n        if self.parameterization == 'all_vertices':\n          layer_name = '{}_{}'.format(RTL_LATTICE_NAME, monotonicities_str)\n          lattice_sizes = [self.lattice_size] * self.lattice_rank\n          kernel_initializer = lattice_layer.create_kernel_initializer(\n              kernel_initializer_id=self.kernel_initializer,\n              lattice_sizes=lattice_sizes,\n              monotonicities=monotonicities,\n              output_min=self.output_min,\n      
        output_max=self.output_max,\n              unimodalities=None,\n              joint_unimodalities=None,\n              init_min=self.init_min,\n              init_max=self.init_max)\n          self._lattice_layers[str(monotonicities)] = lattice_layer.Lattice(\n              lattice_sizes=lattice_sizes,\n              units=units,\n              monotonicities=monotonicities,\n              output_min=self.output_min,\n              output_max=self.output_max,\n              num_projection_iterations=self.num_projection_iterations,\n              monotonic_at_every_step=self.monotonic_at_every_step,\n              clip_inputs=self.clip_inputs,\n              interpolation=self.interpolation,\n              kernel_initializer=kernel_initializer,\n              kernel_regularizer=kernel_regularizer,\n              name=layer_name,\n          )\n        elif self.parameterization == 'kronecker_factored':\n          layer_name = '{}_{}'.format(RTL_KFL_NAME, monotonicities_str)\n          kernel_initializer = kfll.create_kernel_initializer(\n              kernel_initializer_id=self.kernel_initializer,\n              monotonicities=monotonicities,\n              output_min=self.output_min,\n              output_max=self.output_max,\n              init_min=self.init_min,\n              init_max=self.init_max)\n          self._lattice_layers[str(\n              monotonicities)] = kfll.KroneckerFactoredLattice(\n                  lattice_sizes=self.lattice_size,\n                  units=units,\n                  num_terms=self.num_terms,\n                  monotonicities=monotonicities,\n                  output_min=self.output_min,\n                  output_max=self.output_max,\n                  clip_inputs=self.clip_inputs,\n                  kernel_initializer=kernel_initializer,\n                  scale_initializer='scale_initializer',\n                  name=layer_name)\n        else:\n          raise ValueError('Unknown type of parameterization: {}'.format(\n              self.parameterization))\n    super(RTL, self).build(input_shape)\n\n  def call(self, x, **kwargs):\n    \"\"\"Standard Keras call() method.\"\"\"\n    if not isinstance(x, dict):\n      x = {'unconstrained': x}\n\n    # Flatten the input.\n    # The order for flattening should match the order in _get_rtl_structure.\n    input_tensors = []\n    for input_key in sorted(x.keys()):\n      items = x[input_key]\n      if isinstance(items, list):\n        input_tensors.extend(items)\n      else:\n        input_tensors.append(items)\n    if len(input_tensors) == 1:\n      flattened_input = input_tensors[0]\n    else:\n      flattened_input = tf.concat(input_tensors, axis=1, name=RTL_CONCAT_NAME)\n\n    # outputs_for_monotonicity[0] are non-monotonic outputs\n    # outputs_for_monotonicity[1] are monotonic outputs\n    outputs_for_monotonicity = [[], []]\n    for monotonicities, inputs_for_units in self._rtl_structure:\n      if len(inputs_for_units) == 1:\n        inputs_for_units = inputs_for_units[0]\n      lattice_inputs = tf.gather(flattened_input, inputs_for_units, axis=1)\n      output_monotonicity = max(monotonicities)\n      # Call each lattice layer and store based on output monotonicy.\n      outputs_for_monotonicity[output_monotonicity].append(\n          self._lattice_layers[str(monotonicities)](lattice_inputs))\n\n    if self.separate_outputs:\n      separate_outputs = {}\n      for monotoncity, output_key in [(0, 'unconstrained'), (1, 'increasing')]:\n        lattice_outputs = 
outputs_for_monotonicity[monotoncity]\n        if not lattice_outputs:\n          # Do not need to add empty list to the output.\n          pass\n        elif len(lattice_outputs) == 1:\n          separate_outputs[output_key] = lattice_outputs[0]\n        else:\n          separate_outputs[output_key] = tf.concat(lattice_outputs, axis=1)\n      return separate_outputs\n    else:\n      joint_outputs = outputs_for_monotonicity[0] + outputs_for_monotonicity[1]\n      if len(joint_outputs) > 1:\n        joint_outputs = tf.concat(joint_outputs, axis=1)\n      else:\n        joint_outputs = joint_outputs[0]\n      if self.average_outputs:\n        joint_outputs = tf.reduce_mean(joint_outputs, axis=-1, keepdims=True)\n      return joint_outputs\n\n  def compute_output_shape(self, input_shape):\n    \"\"\"Standard Keras compute_output_shape() method.\"\"\"\n    if isinstance(input_shape, dict):\n      batch_size = list(input_shape.values())[0][0]\n    else:\n      batch_size = input_shape[0]\n    if not self.separate_outputs:\n      if self.average_outputs:\n        return (batch_size, 1)\n      else:\n        return (batch_size, self.num_lattices)\n    num_outputs = [0, 0]\n    for monotonicities, inputs_for_units in self._rtl_structure:\n      output_monotonicity = max(monotonicities)\n      num_outputs[output_monotonicity] += len(inputs_for_units)\n    output_shape = {}\n    if num_outputs[0]:\n      output_shape['unconstrained'] = (batch_size, num_outputs[0])\n    if num_outputs[1]:\n      output_shape['increasing'] = (batch_size, num_outputs[1])\n    return output_shape\n\n  def get_config(self):\n    \"\"\"Standard Keras get_config() method.\"\"\"\n    config = super(RTL, self).get_config()\n    config.update({\n        'num_lattices': self.num_lattices,\n        'lattice_rank': self.lattice_rank,\n        'lattice_size': self.lattice_size,\n        'output_min': self.output_min,\n        'output_max': self.output_max,\n        'init_min': self.init_min,\n        'init_max': self.init_max,\n        'separate_outputs': self.separate_outputs,\n        'random_seed': self.random_seed,\n        'num_projection_iterations': self.num_projection_iterations,\n        'monotonic_at_every_step': self.monotonic_at_every_step,\n        'clip_inputs': self.clip_inputs,\n        'interpolation': self.interpolation,\n        'parameterization': self.parameterization,\n        'num_terms': self.num_terms,\n        'avoid_intragroup_interaction': self.avoid_intragroup_interaction,\n        'kernel_initializer': self.kernel_initializer,\n        'kernel_regularizer': self.kernel_regularizer,\n        'average_outputs': self.average_outputs,\n    })\n    return config\n\n  def finalize_constraints(self):\n    \"\"\"Ensures layers weights strictly satisfy constraints.\n\n    Applies approximate projection to strictly satisfy specified constraints.\n    If `monotonic_at_every_step == True` there is no need to call this function.\n\n    Returns:\n      In eager mode directly updates weights and returns variable which stores\n      them. 
In graph mode returns a list of `assign_add` ops which have to be\n      executed to update weights.\n    \"\"\"\n    return list(lattice_layer.finalize_constraints()\n                for lattice_layer in self._lattice_layers.values())\n\n  def assert_constraints(self, eps=1e-6):\n    \"\"\"Asserts that weights satisfy all constraints.\n\n    In graph mode builds and returns a list of assertion ops.\n    In eager mode directly executes assertions.\n\n    Args:\n      eps: allowed constraints violation.\n\n    Returns:\n      List of assertion ops in graph mode or immediately asserts in eager mode.\n    \"\"\"\n    assertions = []\n    for layer in self._lattice_layers.values():\n      assertions.extend(layer.assert_constraints(eps))\n    return assertions\n\n  def _get_rtl_structure(self, input_shape):\n    \"\"\"Returns the RTL structure for the given input_shape.\n\n    Args:\n      input_shape: Input shape to the layer. Must be a dict matching the format\n        described in the layer description.\n\n    Raises:\n      ValueError: If the structure is too small to include all the inputs.\n\n    Returns:\n      A list of `(monotonicities, lattices)` tuples, where `monotonicities` is\n      the tuple of lattice monotonicities, and `lattices` is a list of lists of\n      indices into the flattened input to the layer.\n    \"\"\"\n    if not isinstance(input_shape, dict):\n      input_shape = {'unconstrained': input_shape}\n\n    # Calculate the flattened input to the RTL layer. rtl_inputs will be a list\n    # of _RTLInput items, each including information about the monotonicity,\n    # input group and input index for each input to the layer.\n    # The order for flattening should match the order in the call method.\n    rtl_inputs = []\n    group = 0  # group id for the input\n    input_index = 0  # index into the flattened input\n    for input_key in sorted(input_shape.keys()):\n      shapes = input_shape[input_key]\n      if input_key == 'unconstrained':\n        monotonicity = 0\n      elif input_key == 'increasing':\n        monotonicity = 1\n      else:\n        raise ValueError(\n            'Unrecognized key in the input to the RTL layer: {}'.format(\n                input_key))\n\n      if not isinstance(shapes, list):\n        # Get the shape after a split. 
See single dense tensor input format in\n        # the layer comments.\n        shapes = [(shapes[0], 1)] * shapes[1]\n\n      for shape in shapes:\n        for _ in range(shape[1]):\n          rtl_inputs.append(\n              _RTLInput(\n                  monotonicity=monotonicity,\n                  group=group,\n                  input_index=input_index))\n          input_index += 1\n        group += 1\n\n    total_usage = self.num_lattices * self.lattice_rank\n    if total_usage < len(rtl_inputs):\n      raise ValueError(\n          'RTL layer with {}x{}D lattices is too small to use all the {} input '\n          'features'.format(self.num_lattices, self.lattice_rank,\n                            len(rtl_inputs)))\n\n    # Repeat the features to fill all the slots in the RTL layer.\n    rs = np.random.RandomState(self.random_seed)\n    rs.shuffle(rtl_inputs)\n    rtl_inputs = rtl_inputs * (1 + total_usage // len(rtl_inputs))\n    rtl_inputs = rtl_inputs[:total_usage]\n    rs.shuffle(rtl_inputs)\n\n    # Start with random lattices, possibly with repeated groups in lattices.\n    lattices = []\n    for lattice_index in range(self.num_lattices):\n      lattices.append(\n          rtl_inputs[lattice_index * self.lattice_rank:(lattice_index + 1) *\n                     self.lattice_rank])\n\n    # Swap features between lattices to make sure only a single input from each\n    # group is used in each lattice.\n    changed = True\n    iteration = 0\n    while changed and self.avoid_intragroup_interaction:\n      if iteration > _MAX_RTL_SWAPS:\n        logging.info('Some lattices in the RTL layer might use features from '\n                     'the same input group')\n        break\n      changed = False\n      iteration += 1\n      for lattice_0, lattice_1 in itertools.combinations(lattices, 2):\n        # For every pair of lattices: lattice_0, lattice_1\n        for index_0, index_1 in itertools.product(\n            range(len(lattice_0)), range(len(lattice_1))):\n          # Consider swapping lattice_0[index_0] with lattice_1[index_1]\n          rest_lattice_0 = list(lattice_0)\n          rest_lattice_1 = list(lattice_1)\n          feature_0 = rest_lattice_0.pop(index_0)\n          feature_1 = rest_lattice_1.pop(index_1)\n          if feature_0.group == feature_1.group:\n            continue\n\n          # Swap if a group is repeated and a swap fixes it.\n          rest_lattice_groups_0 = list(\n              lattice_input.group for lattice_input in rest_lattice_0)\n          rest_lattice_groups_1 = list(\n              lattice_input.group for lattice_input in rest_lattice_1)\n          if ((feature_0.group in rest_lattice_groups_0) and\n              (feature_0.group not in rest_lattice_groups_1) and\n              (feature_1.group not in rest_lattice_groups_0)):\n            lattice_0[index_0], lattice_1[index_1] = (lattice_1[index_1],\n                                                      lattice_0[index_0])\n            changed = True\n\n    # Arrange into combined lattices layers. 
Lattices with similar monotonicities\n    # can use the same tfl.layers.Lattice layer.\n    # Create a dict: monotonicity -> list of lists of input indices.\n    lattices_for_monotonicities = collections.defaultdict(list)\n    for lattice in lattices:\n      lattice.sort(key=lambda lattice_input: lattice_input.monotonicity)\n      monotonicities = tuple(\n          lattice_input.monotonicity for lattice_input in lattice)\n      lattice_input_indices = list(\n          lattice_input.input_index for lattice_input in lattice)\n      lattices_for_monotonicities[monotonicities].append(lattice_input_indices)\n\n    return sorted(lattices_for_monotonicities.items())\n"
  },
  {
    "path": "tensorflow_lattice/python/rtl_lib.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Implementation of algorithms required for RTL layer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport six\n\n\ndef verify_hyperparameters(lattice_size,\n                           input_shape=None,\n                           output_min=None,\n                           output_max=None,\n                           interpolation=\"hypercube\",\n                           parameterization=\"all_vertices\",\n                           kernel_initializer=None,\n                           kernel_regularizer=None):\n  \"\"\"Verifies that all given hyperparameters are consistent.\n\n  See `tfl.layers.RTL` class level comment for detailed description of\n  arguments.\n\n  Args:\n    lattice_size: Lattice size to check againts.\n    input_shape: Shape of layer input.\n    output_min: Minimum output of `RTL` layer.\n    output_max: Maximum output of `RTL` layer.\n    interpolation: One of 'simplex' or 'hypercube' interpolation.\n    parameterization: One of 'all_vertices' or 'kronecker_factored'\n      parameterizations.\n    kernel_initializer: Initizlier to check against.\n    kernel_regularizer: Regularizers to check against.\n\n  Raises:\n    ValueError: If lattice_size < 2.\n    KeyError: If input_shape is a dict with incorrect keys.\n    ValueError: If output_min >= output_max.\n    ValueError: If interpolation is not one of 'simplex' or 'hypercube'.\n    ValueError: If parameterization is 'kronecker_factored' and\n      kernel_initializer is 'linear_initializer'.\n    ValueError: If parameterization is 'kronecker_factored' and\n      kernel_regularizer is not None.\n    ValueError: If kernel_regularizer contains a tuple with len != 3.\n    ValueError: If kernel_regularizer contains a tuple with non-float l1 value.\n    ValueError: If kernel_regularizer contains a tuple with non-flaot l2 value.\n\n  \"\"\"\n  if lattice_size < 2:\n    raise ValueError(\n        \"Lattice size must be at least 2. Given: {}\".format(lattice_size))\n\n  if input_shape:\n    if isinstance(input_shape, dict):\n      for key in input_shape:\n        if key not in [\"unconstrained\", \"increasing\"]:\n          raise KeyError(\"Input shape keys should be either 'unconstrained' \"\n                         \"or 'increasing', but seeing: {}\".format(key))\n\n  if output_min is not None and output_max is not None:\n    if output_min >= output_max:\n      raise ValueError(\"'output_min' must be not greater than 'output_max'. 
\"\n                       \"'output_min': %f, 'output_max': %f\" %\n                       (output_min, output_max))\n\n  if interpolation not in [\"hypercube\", \"simplex\"]:\n    raise ValueError(\"RTL interpolation type should be either 'simplex' \"\n                     \"or 'hypercube': %s\" % interpolation)\n\n  if (parameterization == \"kronecker_factored\" and\n      kernel_initializer == \"linear_initializer\"):\n    raise ValueError(\"'kronecker_factored' parameterization does not currently \"\n                     \"support linear iniitalization. 'parameterization': %s, \"\n                     \"'kernel_initializer': %s\" %\n                     (parameterization, kernel_initializer))\n\n  if (parameterization == \"kronecker_factored\" and\n      kernel_regularizer is not None):\n    raise ValueError(\"'kronecker_factored' parameterization does not currently \"\n                     \"support regularization. 'parameterization': %s, \"\n                     \"'kernel_regularizer': %s\" %\n                     (parameterization, kernel_regularizer))\n\n  if kernel_regularizer:\n    if isinstance(kernel_regularizer, list):\n      regularizers = kernel_regularizer\n      if isinstance(kernel_regularizer[0], six.string_types):\n        regularizers = [kernel_regularizer]\n      for regularizer in regularizers:\n        if len(regularizer) != 3:\n          raise ValueError(\"Regularizer tuples/lists must have three elements \"\n                           \"(type, l1, and l2). Given: {}\".format(regularizer))\n        _, l1, l2 = regularizer\n        if not isinstance(l1, float):\n          raise ValueError(\n              \"Regularizer l1 must be a single float. Given: {}\".format(\n                  type(l1)))\n        if not isinstance(l2, float):\n          raise ValueError(\n              \"Regularizer l2 must be a single float. Given: {}\".format(\n                  type(l2)))\n"
  },
  {
    "path": "tensorflow_lattice/python/rtl_test.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for Lattice Layer.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tempfile\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_lattice.python import linear_layer\nfrom tensorflow_lattice.python import pwl_calibration_layer\nfrom tensorflow_lattice.python import rtl_layer\n# pylint: disable=g-import-not-at-top\n# Use Keras 2.\nversion_fn = getattr(tf.keras, \"version\", None)\nif version_fn and version_fn().startswith(\"3.\"):\n  import tf_keras as keras\nelse:\n  keras = tf.keras\n\n\nclass RTLTest(parameterized.TestCase, tf.test.TestCase):\n\n  def setUp(self):\n    super(RTLTest, self).setUp()\n    self.disable_all = False\n    keras.utils.set_random_seed(42)\n\n  def testRTLInputShapes(self):\n    if self.disable_all:\n      return\n    data_size = 100\n\n    # Dense input format.\n    a = np.random.random_sample(size=(data_size, 10))\n    b = np.random.random_sample(size=(data_size, 20))\n    target_ab = (\n        np.max(a, axis=1, keepdims=True) + np.min(b, axis=1, keepdims=True))\n\n    input_a = keras.layers.Input(shape=(10,))\n    input_b = keras.layers.Input(shape=(20,))\n\n    rtl_0 = rtl_layer.RTL(num_lattices=6, lattice_rank=5)\n    rtl_outputs = rtl_0({\"unconstrained\": input_a, \"increasing\": input_b})\n    outputs = keras.layers.Dense(1)(rtl_outputs)\n    model = keras.Model(inputs=[input_a, input_b], outputs=outputs)\n    model.compile(loss=\"mse\")\n    model.fit([a, b], target_ab)\n    model.predict([a, b])\n\n    # Inputs to be calibrated.\n    c = np.random.random_sample(size=(data_size, 1))\n    d = np.random.random_sample(size=(data_size, 1))\n    e = np.random.random_sample(size=(data_size, 1))\n    f = np.random.random_sample(size=(data_size, 1))\n    target_cdef = np.sin(np.pi * c) * np.cos(np.pi * d) - e * f\n\n    input_c = keras.layers.Input(shape=(1,))\n    input_d = keras.layers.Input(shape=(1,))\n    input_e = keras.layers.Input(shape=(1,))\n    input_f = keras.layers.Input(shape=(1,))\n\n    input_keypoints = np.linspace(0.0, 1.0, 10)\n    calib_c = pwl_calibration_layer.PWLCalibration(\n        units=2,\n        input_keypoints=input_keypoints,\n        output_min=0.0,\n        output_max=1.0)(\n            input_c)\n    calib_d = pwl_calibration_layer.PWLCalibration(\n        units=3,\n        input_keypoints=input_keypoints,\n        output_min=0.0,\n        output_max=1.0)(\n            input_d)\n    calib_e = pwl_calibration_layer.PWLCalibration(\n        units=4,\n        input_keypoints=input_keypoints,\n        output_min=0.0,\n        output_max=1.0,\n        monotonicity=\"decreasing\")(\n            input_e)\n    calib_f = pwl_calibration_layer.PWLCalibration(\n        units=5,\n        input_keypoints=input_keypoints,\n        output_min=0.0,\n        output_max=1.0,\n        monotonicity=\"decreasing\")(\n 
           input_f)\n\n    rtl_0 = rtl_layer.RTL(num_lattices=10, lattice_rank=3)\n    rtl_0_outputs = rtl_0({\n        \"unconstrained\": [calib_c, calib_d],\n        \"increasing\": [calib_e, calib_f]\n    })\n    outputs = linear_layer.Linear(\n        num_input_dims=10, monotonicities=[1] * 10)(\n            rtl_0_outputs)\n    model = keras.Model(\n        inputs=[input_c, input_d, input_e, input_f], outputs=outputs\n    )\n    model.compile(loss=\"mse\")\n    model.fit([c, d, e, f], target_cdef)\n    model.predict([c, d, e, f])\n\n    # Two layer RTL model.\n    rtl_0 = rtl_layer.RTL(\n        num_lattices=10,\n        lattice_rank=3,\n        output_min=0.0,\n        output_max=1.0,\n        separate_outputs=True)\n    rtl_0_outputs = rtl_0({\n        \"unconstrained\": [calib_c, calib_d],\n        \"increasing\": [calib_e, calib_f]\n    })\n    rtl_1 = rtl_layer.RTL(num_lattices=3, lattice_rank=4)\n    rtl_1_outputs = rtl_1(rtl_0_outputs)\n    outputs = linear_layer.Linear(\n        num_input_dims=3, monotonicities=[1] * 3)(\n            rtl_1_outputs)\n    model = keras.Model(\n        inputs=[input_c, input_d, input_e, input_f], outputs=outputs\n    )\n    model.compile(loss=\"mse\")\n    model.fit([c, d, e, f], target_cdef)\n    model.predict([c, d, e, f])\n\n  def testRTLOutputShape(self):\n    if self.disable_all:\n      return\n\n    # Multiple Outputs Per Lattice\n    input_shape, output_shape = (30,), (None, 6)\n    input_a = keras.layers.Input(shape=input_shape)\n    rtl_0 = rtl_layer.RTL(num_lattices=6, lattice_rank=5)\n    output = rtl_0(input_a)\n    self.assertAllEqual(output_shape, rtl_0.compute_output_shape(input_a.shape))\n    self.assertAllEqual(output_shape, output.shape)\n\n    # Average Outputs\n    output_shape = (None, 1)\n    rtl_1 = rtl_layer.RTL(num_lattices=6, lattice_rank=5, average_outputs=True)\n    output = rtl_1(input_a)\n    self.assertAllEqual(output_shape, rtl_1.compute_output_shape(input_a.shape))\n    self.assertAllEqual(output_shape, output.shape)\n\n  def testRTLSaveLoad(self):\n    if self.disable_all:\n      return\n\n    input_c = keras.layers.Input(shape=(1,))\n    input_d = keras.layers.Input(shape=(1,))\n    input_e = keras.layers.Input(shape=(1,))\n    input_f = keras.layers.Input(shape=(1,))\n\n    input_keypoints = np.linspace(0.0, 1.0, 10)\n    calib_c = pwl_calibration_layer.PWLCalibration(\n        units=2,\n        input_keypoints=input_keypoints,\n        output_min=0.0,\n        output_max=1.0)(\n            input_c)\n    calib_d = pwl_calibration_layer.PWLCalibration(\n        units=3,\n        input_keypoints=input_keypoints,\n        output_min=0.0,\n        output_max=1.0)(\n            input_d)\n    calib_e = pwl_calibration_layer.PWLCalibration(\n        units=4,\n        input_keypoints=input_keypoints,\n        output_min=0.0,\n        output_max=1.0,\n        monotonicity=\"decreasing\")(\n            input_e)\n    calib_f = pwl_calibration_layer.PWLCalibration(\n        units=5,\n        input_keypoints=input_keypoints,\n        output_min=0.0,\n        output_max=1.0,\n        monotonicity=\"decreasing\")(\n            input_f)\n\n    rtl_0 = rtl_layer.RTL(\n        num_lattices=10,\n        lattice_rank=3,\n        output_min=0.0,\n        output_max=1.0,\n        separate_outputs=True)\n    rtl_0_outputs = rtl_0({\n        \"unconstrained\": [calib_c, calib_d],\n        \"increasing\": [calib_e, calib_f]\n    })\n    rtl_1 = rtl_layer.RTL(num_lattices=3, lattice_rank=4)\n    rtl_1_outputs = rtl_1(rtl_0_outputs)\n    
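# Combine the second RTL layer's 3 outputs with a monotonic Linear layer.\n    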
outputs = linear_layer.Linear(\n        num_input_dims=3, monotonicities=[1] * 3)(\n            rtl_1_outputs)\n    model = keras.Model(\n        inputs=[input_c, input_d, input_e, input_f], outputs=outputs\n    )\n    model.compile(loss=\"mse\")\n    model.use_legacy_config = True\n\n    with tempfile.NamedTemporaryFile(suffix=\".h5\") as f:\n      model.save(f.name)\n      _ = keras.models.load_model(\n          f.name,\n          custom_objects={\n              \"RTL\": rtl_layer.RTL,\n              \"PWLCalibration\": pwl_calibration_layer.PWLCalibration,\n              \"Linear\": linear_layer.Linear,\n          },\n      )\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_lattice/python/test_utils.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Helpers to train simple model for tests and print debug output.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport time\n\nfrom absl import logging\nimport numpy as np\n\n\nclass TimeTracker(object):\n  \"\"\"Tracks time.\n\n  Keeps track of time spent in its scope and appends it to 'list_to_append'\n  on exit from scope divided by 'num_steps' if provided.\n\n  Example:\n    training_step_times = []\n    with TimeTracker(training_step_times, num_steps=num_epochs):\n      model.fit(... epochs=num_epochs ...)\n    print np.median(training_step_times)\n  \"\"\"\n\n  def __init__(self, list_to_append, num_steps=1):\n    self._list_to_append = list_to_append\n    self._num_steps = float(num_steps)\n\n  def __enter__(self):\n    self._start_time = time.time()\n    return self\n\n  def __exit__(self, unuesd_type, unuesd_value, unuesd_traceback):\n    duration = time.time() - self._start_time\n    self._list_to_append.append(\n        duration / self._num_steps if self._num_steps else 0.0)\n\n\ndef run_training_loop(config,\n                      training_data,\n                      keras_model,\n                      input_dtype=np.float32,\n                      label_dtype=np.float32):\n  \"\"\"Trains models and prints debug info.\n\n  Args:\n    config: dictionary of test case parameters. 
See tests for TensorFlow Lattice\n      layers.\n    training_data: tuple: (training_inputs, labels) where\n      training_inputs and labels are proper data to train models passed via\n      other parameters.\n    keras_model: Keras model to train on training_data.\n    input_dtype: dtype for input conversion.\n    label_dtype: dtype for label conversion.\n\n  Returns:\n    Loss measured on training data and tf.session() if one was initialized\n    explicitly during training.\n  \"\"\"\n  (training_inputs, training_labels) = training_data\n  np_training_inputs = np.asarray(training_inputs).astype(input_dtype)\n  np_training_labels = np.asarray(training_labels).astype(label_dtype)\n\n  logging.info(\" {0: <10}{1: <10}\".format(\"it\", \"Loss\"))\n\n  num_steps = 10\n  training_step_times = []\n  for step in range(num_steps):\n    begin = (config[\"num_training_epoch\"] * step) // num_steps\n    end = (config[\"num_training_epoch\"] * (step + 1)) // num_steps\n    num_epochs = end - begin\n    if num_epochs == 0:\n      continue\n\n    loss = keras_model.evaluate(np_training_inputs, np_training_labels,\n                                batch_size=len(np_training_inputs),\n                                verbose=0)\n    with TimeTracker(training_step_times, num_steps=num_epochs):\n      keras_model.fit(np_training_inputs, np_training_labels,\n                      batch_size=len(np_training_inputs),\n                      epochs=num_epochs,\n                      verbose=0)\n    logging.info(\"{0: <10}{1: <10,.6f}\".format(begin, loss))\n  # End of: 'for step in range(num_steps):'\n\n  loss = keras_model.evaluate(np_training_inputs, np_training_labels,\n                              batch_size=len(np_training_inputs),\n                              verbose=0)\n  logging.info(\"Final loss: %f\", loss)\n\n  if training_step_times:\n    logging.info(\"Median training step time: %f\",\n                 np.median(training_step_times))\n\n  return loss\n\n\ndef two_dim_mesh_grid(num_points, x_min, y_min, x_max, y_max):\n  \"\"\"Generates uniform 2-d mesh grid for 3-d surfaces visualisation via pyplot.\n\n  Uniformly distributes 'num_points' within rectangle:\n  (x_min, y_min) - (x_max, y_max)\n  'num_points' should be such that uniform distribution is possible. 
In other\n  words there should exist such integers 'x_points' and 'y_points' that:\n  - x_points * y_points == num_points\n  - x_points / y_points == (x_max - x_min) / (y_max - y_min)\n\n  Args:\n    num_points: number of points in the grid.\n    x_min: bounds of the grid.\n    y_min: bounds of the grid.\n    x_max: bounds of the grid.\n    y_max: bounds of the grid.\n\n  Returns:\n    Tuple containing 2 numpy arrays which represent X and Y coordinates of the\n    mesh grid.\n\n  Raises:\n    ValueError: if it's impossible to uniformly distribute 'num_points' across\n    specified grid.\n\n  \"\"\"\n  x_size = x_max - x_min\n  y_size = y_max - y_min\n  x_points = (num_points * x_size / y_size)**0.5\n  y_points = num_points / x_points\n\n  eps = 1e-7\n  is_int = lambda x: abs(x - int(x + eps)) < eps\n  if not is_int(x_points) or not is_int(y_points):\n    raise ValueError(\"Cannot evenly distribute %d points across sides of \"\n                     \"lengths: %f and %f\" % (num_points, x_size, y_size))\n\n  x_grid = np.linspace(start=x_min, stop=x_max, num=int(x_points + eps))\n  y_grid = np.linspace(start=y_min, stop=y_max, num=int(y_points + eps))\n\n  # Convert list returned by meshgrid() to tuple so we can easily distinguish\n  # mesh grid vs list of points.\n  return tuple(np.meshgrid(x_grid, y_grid))\n\n\ndef sample_uniformly(num_points, lower_bounds, upper_bounds):\n  \"\"\"Deterministically generates num_points random points within bounds.\n\n  Points will be such that:\n  lower_bounds[i] <= p[i] <= upper_bounds[i]\n\n  Number of dimensions is defined by the length of the lower_bounds list.\n\n  Args:\n    num_points: number of points to generate.\n    lower_bounds: list or tuple of lower bounds.\n    upper_bounds: list or tuple of upper bounds.\n\n  Returns:\n    List of generated points.\n  \"\"\"\n  if len(lower_bounds) != len(upper_bounds):\n    raise ValueError(\"Lower and upper bounds must have same length. 
They are: \"\n                     \"lower_bounds: %s, upper_bounds: %s\" %\n                     (lower_bounds, upper_bounds))\n  np.random.seed(41)\n  x = []\n  for _ in range(num_points):\n    point = [\n        lower + np.random.random() * (upper - lower)\n        for lower, upper in zip(lower_bounds, upper_bounds)\n    ]\n    x.append(np.asarray(point))\n  return x\n\n\ndef get_hypercube_interpolation_fn(coefficients):\n  \"\"\"Returns function which does hypercube interpolation.\n\n  This is only for 2^d lattice aka hypercube.\n\n  Args:\n    coefficients: coefficients of hypercube ordered according to index of\n      corresponding vertex.\n\n  Returns:\n    Function which takes a d-dimensional point and performs hypercube\n    interpolation with given coefficients.\n  \"\"\"\n\n  def hypercube_interpolation_fn(x):\n    \"\"\"Does hypercube interpolation.\"\"\"\n    if 2**len(x) != len(coefficients):\n      raise ValueError(\"Number of coefficients(%d) does not correspond to \"\n                       \"dimension 'x'(%s)\" % (len(coefficients), x))\n    result = 0.0\n    for coefficient_index in range(len(coefficients)):\n      weight = 1.0\n      for input_dimension in range(len(x)):\n        if coefficient_index & (1 << input_dimension):\n          # If statement checks whether 'input_dimension' bit of\n          # 'coefficient_index' is set to 1.\n          weight *= x[input_dimension]\n        else:\n          weight *= (1.0 - x[input_dimension])\n      result += coefficients[coefficient_index] * weight\n    return result\n\n  return hypercube_interpolation_fn\n\n\ndef get_linear_lattice_interpolation_fn(lattice_sizes, monotonicities,\n                                        output_min, output_max):\n  \"\"\"Returns function which does lattice interpolation.\n\n  Returned function matches lattice_layer.LinearInitializer with corresponding\n  parameters.\n\n  Args:\n    lattice_sizes: list or tuple of integers which represents lattice sizes.\n    monotonicities: monotonicity constraints.\n    output_min: minimum output of linear function.\n    output_max: maximum output of linear function.\n\n  Returns:\n    Function which takes a d-dimensional point and performs lattice\n    interpolation assuming lattice weights are such that the lattice\n    represents a linear function with given output_min and output_max. All\n    monotonic dimensions of this linear function contribute with the same\n    weight regardless of the number of vertices per dimension. All\n    non-monotonic dimensions have weight 0.0.\n  \"\"\"\n\n  def linear_interpolation_fn(x):\n    \"\"\"Linear along monotonic dims and 0.0 along non monotonic.\"\"\"\n    result = output_min\n    num_monotonic_dims = len(monotonicities) - monotonicities.count(0)\n    if num_monotonic_dims == 0:\n      local_monotonicities = [1] * len(lattice_sizes)\n      num_monotonic_dims = len(lattice_sizes)\n    else:\n      local_monotonicities = monotonicities\n\n    weight = (output_max - output_min) / num_monotonic_dims\n    for i in range(len(x)):\n      if local_monotonicities[i]:\n        result += x[i] * weight / (lattice_sizes[i] - 1.0)\n    return result\n\n  return linear_interpolation_fn\n"
  },
  {
    "path": "tensorflow_lattice/python/utils.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Helpers shared by multiple modules in TFL.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport six\n\n\n# TODO: update library not to explicitly check if None so we can return\n# an empty list instead of None for these canonicalization methods.\ndef canonicalize_convexity(convexity):\n  \"\"\"Converts string constants representing convexity into integers.\n\n  Args:\n    convexity: The convexity hyperparameter of `tfl.layers.PWLCalibration`\n      layer.\n\n  Returns:\n    convexity represented as -1, 0, 1, or None.\n\n  Raises:\n    ValueError: If convexity is not in the set\n      {-1, 0, 1, 'concave', 'none', 'convex'}.\n  \"\"\"\n  if convexity is None:\n    return None\n\n  if convexity in [-1, 0, 1]:\n    return convexity\n  elif isinstance(convexity, six.string_types):\n    if convexity.lower() == \"concave\":\n      return -1\n    if convexity.lower() == \"none\":\n      return 0\n    if convexity.lower() == \"convex\":\n      return 1\n  raise ValueError(\"'convexity' must be from: [-1, 0, 1, 'concave', \"\n                   \"'none', 'convex']. Given: {}\".format(convexity))\n\n\ndef canonicalize_input_bounds(input_bounds):\n  \"\"\"Converts string constant 'none' representing unspecified bound into None.\n\n  Args:\n    input_bounds: The input_min or input_max hyperparameter of\n      `tfl.layers.Linear` layer.\n\n  Returns:\n    A list of [val, val, ...] where val can be a float or None, or the value\n    None if input_bounds is None.\n\n  Raises:\n    ValueError: If one of elements in input_bounds is not a float, None or\n      'none'.\n  \"\"\"\n  if input_bounds:\n    canonicalized = []\n    for item in input_bounds:\n      if isinstance(item, float) or item is None:\n        canonicalized.append(item)\n      elif isinstance(item, six.string_types) and item.lower() == \"none\":\n        canonicalized.append(None)\n      else:\n        raise ValueError(\"Both 'input_min' and 'input_max' elements must be \"\n                         \"either int, float, None, or 'none'. Given: {}\".format(\n                             input_bounds))\n    return canonicalized\n  return None\n\n\ndef canonicalize_monotonicity(monotonicity, allow_decreasing=True):\n  \"\"\"Converts string constants representing monotonicity into integers.\n\n  Args:\n    monotonicity: The monotonicities hyperparameter of a `tfl.layers` Layer\n      (e.g. 
`tfl.layers.PWLCalibration`).\n    allow_decreasing: If decreasing monotonicity is considered a valid\n      monotonicity.\n\n  Returns:\n    monotonicity represented as -1, 0, 1, or None.\n\n  Raises:\n    ValueError: If monotonicity is not in the set\n      {-1, 0, 1, 'decreasing', 'none', 'increasing'} and allow_decreasing is\n      True.\n    ValueError: If monotonicity is not in the set {0, 1, 'none', 'increasing'}\n      and allow_decreasing is False.\n  \"\"\"\n  if monotonicity is None:\n    return None\n\n  if monotonicity in [-1, 0, 1]:\n    if not allow_decreasing and monotonicity == -1:\n      raise ValueError(\n          \"'monotonicities' must be from: [0, 1, 'none', 'increasing']. \"\n          \"Given: {}\".format(monotonicity))\n    return monotonicity\n  elif isinstance(monotonicity, six.string_types):\n    if monotonicity.lower() == \"decreasing\":\n      if not allow_decreasing:\n        raise ValueError(\n            \"'monotonicities' must be from: [0, 1, 'none', 'increasing']. \"\n            \"Given: {}\".format(monotonicity))\n      return -1\n    if monotonicity.lower() == \"none\":\n      return 0\n    if monotonicity.lower() == \"increasing\":\n      return 1\n  raise ValueError(\"'monotonicities' must be from: [-1, 0, 1, 'decreasing', \"\n                   \"'none', 'increasing']. Given: {}\".format(monotonicity))\n\n\ndef canonicalize_monotonicities(monotonicities, allow_decreasing=True):\n  \"\"\"Converts string constants representing monotonicities into integers.\n\n  Args:\n    monotonicities: monotonicities hyperparameter of a `tfl.layers` Layer (e.g.\n      `tfl.layers.Lattice`).\n    allow_decreasing: If decreasing monotonicity is considered a valid\n      monotonicity.\n\n  Returns:\n    A list of monotonicities represented as -1, 0, 1, or the value None\n    if monotonicities is None.\n\n  Raises:\n    ValueError: If one of monotonicities is not in the set\n      {-1, 0, 1, 'decreasing', 'none', 'increasing'} and allow_decreasing is\n      True.\n    ValueError: If one of monotonicities is not in the set\n      {0, 1, 'none', 'increasing'} and allow_decreasing is False.\n  \"\"\"\n  if monotonicities:\n    return [\n        canonicalize_monotonicity(\n            monotonicity, allow_decreasing=allow_decreasing)\n        for monotonicity in monotonicities\n    ]\n  return None\n\n\ndef canonicalize_trust(trusts):\n  \"\"\"Converts string constants representing trust direction into integers.\n\n  Args:\n    trusts: edgeworth_trusts or trapezoid_trusts hyperparameter of\n      `tfl.layers.Lattice` layer.\n\n  Returns:\n    A list of trust constraint tuples of the form\n    (feature_a, feature_b, direction) where direction can be -1 or 1, or the\n    value None if trusts is None.\n\n  Raises:\n    ValueError: If one of trust constraints does not have 3 elements.\n    ValueError: If one of trust constraints' direction is not in the set\n      {-1, 1, 'negative', 'positive'}.\n  \"\"\"\n  if trusts:\n    canonicalized = []\n    for trust in trusts:\n      if len(trust) != 3:\n        raise ValueError(\"Trust constraints must consist of 3 elements. 
Seeing \"\n                         \"constraint tuple {}\".format(trust))\n      feature_a, feature_b, direction = trust\n      if direction in [-1, 1]:\n        canonicalized.append(trust)\n      elif (isinstance(direction, six.string_types) and\n            direction.lower() == \"negative\"):\n        canonicalized.append((feature_a, feature_b, -1))\n      elif (isinstance(direction, six.string_types) and\n            direction.lower() == \"positive\"):\n        canonicalized.append((feature_a, feature_b, 1))\n      else:\n        raise ValueError(\"trust constraint direction must be from: [-1, 1, \"\n                         \"'negative', 'positive']. Given: {}\".format(direction))\n    return canonicalized\n  return None\n\n\ndef canonicalize_unimodalities(unimodalities):\n  \"\"\"Converts string constants representing unimodalities into integers.\n\n  Args:\n    unimodalities: unimodalities hyperparameter of `tfl.layers.Lattice` layer.\n\n  Returns:\n    A list of unimodalities represented as -1, 0, 1, or the value None if\n    unimodalities is None.\n\n  Raises:\n    ValueError: If one of unimodalities is not in the set\n      {-1, 0, 1, 'peak', 'none', 'valley'}.\n  \"\"\"\n  if not unimodalities:\n    return None\n  canonicalized = []\n  for unimodality in unimodalities:\n    if unimodality in [-1, 0, 1]:\n      canonicalized.append(unimodality)\n    elif isinstance(unimodality,\n                    six.string_types) and unimodality.lower() == \"peak\":\n      canonicalized.append(-1)\n    elif isinstance(unimodality,\n                    six.string_types) and unimodality.lower() == \"none\":\n      canonicalized.append(0)\n    elif isinstance(unimodality,\n                    six.string_types) and unimodality.lower() == \"valley\":\n      canonicalized.append(1)\n    else:\n      raise ValueError(\n          \"'unimodalities' elements must be from: [-1, 0, 1, 'peak', 'none', \"\n          \"'valley']. Given: {}\".format(unimodalities))\n  return canonicalized\n\n\ndef count_non_zeros(*iterables):\n  \"\"\"Returns total number of non 0 elements in given iterables.\n\n  Args:\n    *iterables: Any number of the value None or iterables of numeric values.\n  \"\"\"\n  result = 0\n  for iterable in iterables:\n    if iterable is not None:\n      result += sum(1 for element in iterable if element != 0)\n  return result\n"
  },
  {
    "path": "tensorflow_lattice/python/utils_test.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Tests for Tensorflow Lattice utility functions.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl.testing import parameterized\nimport tensorflow as tf\nfrom tensorflow_lattice.python import utils\n\n\nclass UtilsTest(parameterized.TestCase, tf.test.TestCase):\n\n  @parameterized.parameters((-1, -1), (0, 0), (1, 1), (\"concave\", -1),\n                            (\"none\", 0), (\"convex\", 1))\n  def testCanonicalizeConvexity(self, convexity,\n                                expected_canonicalized_convexity):\n    canonicalized_convexity = utils.canonicalize_convexity(convexity)\n    self.assertEqual(canonicalized_convexity, expected_canonicalized_convexity)\n\n  @parameterized.parameters((-2), (0.5), (3), (\"invalid_convexity\"),\n                            (\"concaves\"), (\"nonw\"), (\"conve\"))\n  def testInvalidConvexity(self, invalid_convexity):\n    error_message = (\n        \"'convexity' must be from: [-1, 0, 1, 'concave', 'none', 'convex']. \"\n        \"Given: {}\").format(invalid_convexity)\n    with self.assertRaisesWithLiteralMatch(ValueError, error_message):\n      utils.canonicalize_convexity(invalid_convexity)\n\n  # Note: must use mapping format because otherwise input parameter list is\n  # considered multiple parameters (not just a single list parameter).\n  @parameterized.parameters(\n      {\n          \"input_bounds\": [0.0, -3.0],\n          \"expected_canonicalized_input_bounds\": [0.0, -3.0]\n      }, {\n          \"input_bounds\": [float(\"-inf\"), 0.12345],\n          \"expected_canonicalized_input_bounds\": [float(\"-inf\"), 0.12345]\n      }, {\n          \"input_bounds\": [\"none\", None],\n          \"expected_canonicalized_input_bounds\": [None, None]\n      })\n  def testCanonicalizeInputBounds(self, input_bounds,\n                                  expected_canonicalized_input_bounds):\n    canonicalized_input_bounds = utils.canonicalize_input_bounds(input_bounds)\n    self.assertAllEqual(canonicalized_input_bounds,\n                        expected_canonicalized_input_bounds)\n\n  @parameterized.parameters({\"invalid_input_bounds\": [0, 1.0, 2.0]},\n                            {\"invalid_input_bounds\": [None, \"nonw\"]})\n  def testInvalidInputBounds(self, invalid_input_bounds):\n    error_message = (\n        \"Both 'input_min' and 'input_max' elements must be either int, float, \"\n        \"None, or 'none'. 
Given: {}\").format(invalid_input_bounds)\n    with self.assertRaisesWithLiteralMatch(ValueError, error_message):\n      utils.canonicalize_input_bounds(invalid_input_bounds)\n\n  @parameterized.parameters((-1, -1), (0, 0), (1, 1), (\"decreasing\", -1),\n                            (\"none\", 0), (\"increasing\", 1))\n  def testCanonicalizeMonotonicity(self, monotonicity,\n                                   expected_canonicalized_monotonicity):\n    canonicalized_monotonicity = utils.canonicalize_monotonicity(monotonicity)\n    self.assertEqual(canonicalized_monotonicity,\n                     expected_canonicalized_monotonicity)\n\n  @parameterized.parameters((-2), (0.5), (3), (\"invalid_monotonicity\"),\n                            (\"decrease\"), (\"increase\"))\n  def testInvalidMonotonicity(self, invalid_monotonicity):\n    error_message = (\n        \"'monotonicities' must be from: [-1, 0, 1, 'decreasing', 'none', \"\n        \"'increasing']. Given: {}\").format(invalid_monotonicity)\n    with self.assertRaisesWithLiteralMatch(ValueError, error_message):\n      utils.canonicalize_monotonicity(invalid_monotonicity)\n\n  @parameterized.parameters((\"decreasing\"), (-1))\n  def testInvalidDecreasingMonotonicity(self, invalid_monotonicity):\n    error_message = (\n        \"'monotonicities' must be from: [0, 1, 'none', 'increasing']. \"\n        \"Given: {}\").format(invalid_monotonicity)\n    with self.assertRaisesWithLiteralMatch(ValueError, error_message):\n      utils.canonicalize_monotonicity(\n          invalid_monotonicity, allow_decreasing=False)\n\n  # Note: since canonicalize_monotonicities calls canonicalize_monotonicity,\n  # the above test for invalidity is sufficient.\n  @parameterized.parameters(([-1, 0, 1], [-1, 0, 1]),\n                            ([\"decreasing\", \"none\", \"increasing\"], [-1, 0, 1]),\n                            ([\"decreasing\", -1], [-1, -1]),\n                            ([\"none\", 0], [0, 0]), ([\"increasing\", 1], [1, 1]))\n  def testCanonicalizeMonotonicities(self, monotonicities,\n                                     expected_canonicalized_monotonicities):\n    canonicalized_monotonicities = utils.canonicalize_monotonicities(\n        monotonicities)\n    self.assertAllEqual(canonicalized_monotonicities,\n                        expected_canonicalized_monotonicities)\n\n  @parameterized.parameters(([(\"a\", \"b\", -1), (\"b\", \"c\", 1)], [(\"a\", \"b\", -1),\n                                                               (\"b\", \"c\", 1)]),\n                            ([(\"a\", \"b\", \"negative\"),\n                              (\"b\", \"c\", \"positive\")], [(\"a\", \"b\", -1),\n                                                        (\"b\", \"c\", 1)]))\n  def testCanonicalizeTrust(self, trusts, expected_canonicalized_trusts):\n    canonicalized_trusts = utils.canonicalize_trust(trusts)\n    self.assertAllEqual(canonicalized_trusts, expected_canonicalized_trusts)\n\n  # Note 1: this test assumes the first trust in the list has the incorrect\n  # direction. 
A list with a single trust tuple is sufficient.\n  # Note 2: must use mapping format because otherwise input parameter list is\n  # considered multiple parameters (not just a single list parameter).\n  @parameterized.parameters({\"invalid_trusts\": [(\"a\", \"b\", 0)]},\n                            {\"invalid_trusts\": [(\"a\", \"b\", \"negativ\")]})\n  def testInvalidTrustDirection(self, invalid_trusts):\n    error_message = (\n        \"trust constraint direction must be from: [-1, 1, 'negative', \"\n        \"'positive']. Given: {}\").format(invalid_trusts[0][2])\n    with self.assertRaisesWithLiteralMatch(ValueError, error_message):\n      utils.canonicalize_trust(invalid_trusts)\n\n  # Note 1: this test assumes the first trust in the list has the incorrect\n  # size. A list with a single trust tuple is sufficient.\n  # Note 2: must use mapping format because otherwise input parameter list is\n  # considered multiple parameters (not just a single list parameter).\n  @parameterized.parameters({\"invalid_trusts\": [(\"a\", 1)]},\n                            {\"invalid_trusts\": [(\"a\", \"b\", -1, 1)]})\n  def testInvalidTrustLength(self, invalid_trusts):\n    error_message = (\n        \"Trust constraints must consist of 3 elements. Seeing constraint \"\n        \"tuple {}\").format(invalid_trusts[0])\n    with self.assertRaisesWithLiteralMatch(ValueError, error_message):\n      utils.canonicalize_trust(invalid_trusts)\n\n  @parameterized.parameters(([0, 1, 1, 0], [1, 0], 3),\n                            ([0, 0, 0], [0, 0, 0], 0),\n                            ([-1, 0, 0, 1], [0, 0], 2),\n                            (None, [1, 1, 1, 1, 1], 5))\n  def testCountNonZeros(self, monotonicities, unimodalities,\n                        expected_non_zeros):\n    non_zeros = utils.count_non_zeros(monotonicities, unimodalities)\n    self.assertEqual(non_zeros, expected_non_zeros)\n\n  @parameterized.parameters(\n      ([-1, 0, 1], [-1, 0, 1]), ([\"peak\", \"none\", \"valley\"], [-1, 0, 1]),\n      ([\"peak\", -1], [-1, -1]), ([\"none\", 0], [0, 0]), ([\"valley\", 1], [1, 1]))\n  def testCanonicalizeUnimodalities(self, unimodalities,\n                                    expected_canonicalized_unimodalities):\n    canonicalized_unimodalities = utils.canonicalize_unimodalities(\n        unimodalities)\n    self.assertAllEqual(canonicalized_unimodalities,\n                        expected_canonicalized_unimodalities)\n\n  # Note: must use mapping format because otherwise input parameter list is\n  # considered multiple parameters (not just a single list parameter).\n  @parameterized.parameters({\"invalid_unimodalities\": [\"vally\", 0]},\n                            {\"invalid_unimodalities\": [-1, 0, 2]})\n  def testInvalidUnimodalities(self, invalid_unimodalities):\n    error_message = (\n        \"'unimodalities' elements must be from: [-1, 0, 1, 'peak', 'none', \"\n        \"'valley']. Given: {}\").format(invalid_unimodalities)\n    with self.assertRaisesWithLiteralMatch(ValueError, error_message):\n      utils.canonicalize_unimodalities(invalid_unimodalities)\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  }
]